Commit 2f268bc2 authored by Paul

Merge branch 'develop' into mlir-c

parents f75c5a38 aa7ff911
@@ -15,7 +15,7 @@ struct module;
 struct rewrite_pooling
 {
     std::string name() const { return "rewrite_pooling"; }
-    void apply(module& prog) const;
+    void apply(module& m) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -19,22 +19,22 @@ struct module;
 struct rewrite_rnn
 {
     std::string name() const { return "rewrite_rnn"; }
-    void apply(module& prog) const;
+    void apply(module& m) const;

     private:
     // for vanilla rnn operators
-    void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
+    void apply_vanilla_rnn(module& m, instruction_ref ins) const;
     std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
-                                                  module& prog,
+                                                  module& m,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   operation& actv_func) const;
     std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;

     // for gru operators
-    void apply_gru(module& prog, instruction_ref ins) const;
+    void apply_gru(module& m, instruction_ref ins) const;
     std::vector<instruction_ref> gru_cell(bool is_forward,
-                                          module& prog,
+                                          module& m,
                                           instruction_ref ins,
                                           std::vector<instruction_ref> inputs,
                                           int linear_before_reset,
@@ -44,9 +44,9 @@ struct rewrite_rnn
     std::vector<operation> gru_actv_funcs(instruction_ref ins) const;

     // for lstm operators
-    void apply_lstm(module& prog, instruction_ref ins) const;
+    void apply_lstm(module& m, instruction_ref ins) const;
     std::vector<instruction_ref> lstm_cell(bool is_forward,
-                                           module& prog,
+                                           module& m,
                                            instruction_ref ins,
                                            std::vector<instruction_ref> inputs,
                                            const operation& actv_func1,
@@ -55,24 +55,23 @@ struct rewrite_rnn
     std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;

-    bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
+    bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
-    instruction_ref replace_last_hs_output(module& prog,
+    instruction_ref replace_last_hs_output(module& m,
                                            instruction_ref ins,
                                            instruction_ref seq_lens,
                                            instruction_ref last_hs_output,
                                            op::rnn_direction dirct) const;
-    void replace_last_cell_output(module& prog,
+    void replace_last_cell_output(module& m,
                                   instruction_ref ins,
                                   instruction_ref seq_lens,
                                   instruction_ref cell_outputs,
                                   instruction_ref last_cell_output,
                                   op::rnn_direction dirct) const;
-    std::size_t
-    get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
+    std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
-    instruction_ref pad_hidden_states(module& prog,
+    instruction_ref pad_hidden_states(module& m,
                                       instruction_ref seq,
                                       instruction_ref seq_lens,
                                       instruction_ref hs) const;
...
@@ -19,7 +19,7 @@ struct schedule
     schedule_model model{};
     bool enable = true;
     std::string name() const { return "schedule"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -50,7 +50,6 @@ auto to_value_impl(rank<2>, const T& x) -> decltype(x.begin(), x.end(), value{})
     value result = value::array{};
     for(auto&& y : x)
     {
-        auto e = to_value(y);
         result.insert(to_value(y));
     }
     return result;
...
@@ -15,7 +15,7 @@ struct module;
 struct simplify_algebra
 {
     std::string name() const { return "simplify_algebra"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -16,7 +16,7 @@ struct module;
 struct simplify_reshapes
 {
     std::string name() const { return "simplify_reshapes"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
 };
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -120,10 +120,8 @@ struct tensor_view
         return m_data[m_shape.index(this->size() - 1)];
     }

-    // cppcheck-suppress functionConst
     iterator begin() { return {0, {this}}; }
-    // cppcheck-suppress functionConst
     iterator end() { return {this->size(), {this}}; }

     const_iterator begin() const { return {0, {this}}; }
...
@@ -168,7 +168,6 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
 {
     double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
     auto error       = rms_range(r1, r2);
-    // cppcheck-suppress uninitvar
     if(out_error != nullptr)
         *out_error = error;
     return error <= threshold;
...
@@ -5,20 +5,41 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 operation make_op(const std::string& name) { return load_op(name); }

-operation make_op(const std::string& name, const value& v)
+template <class F>
+operation make_op_generic(const std::string& name, F for_each)
 {
-    if(not(v.is_object() or (v.empty() and v.is_array())))
-        MIGRAPHX_THROW("Value is not an object");
     auto op = load_op(name);
     // Merge values
     value w = op.to_value();
-    for(auto&& x : v)
-    {
-        w.at(x.get_key()) = x.without_key();
-    }
+    for_each([&](const auto& key, const auto& x) {
+        if(not w.contains(key))
+            // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
+            MIGRAPHX_THROW("No key '" + key + "' in " + name);
+        w.at(key) = x;
+    });
     op.from_value(w);
     return op;
 }

+operation make_op(const std::string& name,
+                  const std::initializer_list<std::pair<std::string, value>>& v)
+{
+    return make_op_generic(name, [&](auto f) {
+        for(auto&& [key, x] : v)
+            f(key, x);
+    });
+}
+
+operation make_op_from_value(const std::string& name, const value& v)
+{
+    if(not(v.is_object() or (v.empty() and v.is_array())))
+        MIGRAPHX_THROW("Value is not an object for make_op: " + name);
+    return make_op_generic(name, [&](auto f) {
+        for(auto&& x : v)
+            f(x.get_key(), x.without_key());
+    });
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
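A quick sketch of how the two entry points divide the work after this refactor; the concat operator and its axis attribute are only familiar examples, and the value construction assumes migraphx::value's usual object-initializer syntax:

// Brace-initialized attributes use the checked initializer_list overload;
// an unknown key now fails fast with "No key '...' in concat":
auto cat = migraphx::make_op("concat", {{"axis", 0}});

// Attributes already held in a migraphx::value (e.g. deserialized from a
// saved program) go through make_op_from_value instead:
migraphx::value v = {{"axis", 0}};
auto cat2 = migraphx::make_op_from_value("concat", v);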
@@ -22,6 +22,8 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
+
 struct module_impl
 {
     // A list is used to keep references to an instruction stable
@@ -555,8 +557,14 @@ instruction_ref module::find_dangling_reference() const
 void module::finalize(context& ctx)
 {
+    const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
     for(auto ins : iterator_for(*this))
     {
+        if(trace)
+        {
+            std::cout << "Finalize: ";
+            this->debug_print(ins);
+        }
         ins->finalize(ctx);
         for(const auto& smod : ins->module_inputs())
         {
@@ -731,7 +739,6 @@ std::unordered_map<instruction_ref, std::string>
 module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
 {
     os << "migraphx::module p;" << std::endl;
-    // cppcheck-suppress variableScope
     unsigned long seed = 0;
     names = this->print(
         [&](auto ins, auto ins_names) {
...
@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
 {
     std::vector<op_desc> operators() const
     {
+        // clang-format off
         return {{"Abs", "abs"},
                 {"Acos", "acos"},
                 {"Acosh", "acosh"},
@@ -27,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
                 {"Flatten", "flatten"},
                 {"Floor", "floor"},
                 {"Gather", "gather"},
+                {"GatherND", "gathernd"},
                 {"Identity", "identity"},
                 {"IsNaN", "isnan"},
                 {"LeakyRelu", "leaky_relu"},
@@ -37,8 +39,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
                 {"Reciprocal", "recip"},
                 {"Relu", "relu"},
                 {"Round", "round"},
-                {"Scatter", "scatter"},
-                {"ScatterElements", "scatter"},
                 {"Sigmoid", "sigmoid"},
                 {"Sign", "sign"},
                 {"Sin", "sin"},
@@ -47,6 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
                 {"Tan", "tan"},
                 {"Tanh", "tanh"},
                 {"Not", "not"}};
+        // clang-format on
     }

     bool needs_contiguous(const std::string& op_name) const
...
@@ -2,6 +2,7 @@
 #include <migraphx/onnx/checks.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -9,6 +10,9 @@ namespace onnx {

 struct parse_mean : op_parser<parse_mean>
 {
+    const std::set<shape::type_t> float_types = {
+        shape::float_type, shape::half_type, shape::double_type};
+
     std::vector<op_desc> operators() const { return {{"Mean"}}; }

     /// Calculates the element-wise mean of n>=1 input tensors
@@ -24,14 +28,29 @@ struct parse_mean : op_parser<parse_mean>
         auto divisor = info.add_literal(
             migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});

-        return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
-            // Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
-            data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
-
-            if(data_i != args[0])
-                return info.add_broadcastable_binary_op("add", mean, data_i);
-            return data_i;
-        });
+        if(contains(float_types, args[0]->get_shape().type()))
+        {
+            return std::accumulate(args.begin() + 1,
+                                   args.end(),
+                                   info.add_broadcastable_binary_op("div", args[0], divisor),
+                                   [&](auto mean, auto data_i) {
+                                       // Pre-divide each tensor element-wise by n to reduce risk of
+                                       // overflow during summation
+                                       auto div =
+                                           info.add_broadcastable_binary_op("div", data_i, divisor);
+                                       return info.add_broadcastable_binary_op("add", mean, div);
+                                   });
+        }
+        else
+        {
+            // Compute sum before division for integral types
+            auto sum = std::accumulate(
+                args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
+                    return info.add_broadcastable_binary_op("add", accum, data_i);
+                });
+            return info.add_broadcastable_binary_op("div", sum, divisor);
+        }
     }
 };
...
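Why the integral path sums first: integer division truncates term by term, so the pre-divide trick that protects the float path from overflow would destroy precision for integer inputs. A self-contained illustration in plain C++:

#include <cstdio>

int main()
{
    const int n = 3;
    // Pre-dividing each input, as the float path does, truncates every
    // integer term to zero:
    int pre_divide = 1 / n + 1 / n + 1 / n; // 0
    // Summing first and dividing once yields the correct integral mean:
    int sum_first = (1 + 1 + 1) / n; // 1
    std::printf("pre-divide: %d, sum-first: %d\n", pre_divide, sum_first);
    return 0;
}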
@@ -19,7 +19,9 @@ struct parse_pooling : op_parser<parse_pooling>
         return {{"AveragePool", "average"},
                 {"GlobalAveragePool", "average"},
                 {"GlobalMaxPool", "max"},
-                {"MaxPool", "max"}};
+                {"MaxPool", "max"},
+                {"LpPool", "lpnorm"},
+                {"GlobalLpPool", "lpnorm"}};
     }

     instruction_ref parse(const op_desc& opd,
@@ -27,14 +29,16 @@ struct parse_pooling : op_parser<parse_pooling>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
+        const std::unordered_map<std::string, op::pooling_mode> mode_map = {
+            {"max", op::pooling_mode::max},
+            {"average", op::pooling_mode::average},
+            {"lpnorm", op::pooling_mode::lpnorm}};
         std::string mode = opd.op_name;
-        if(mode != "max" && mode != "average")
+        if(not contains(mode_map, mode))
         {
-            MIGRAPHX_THROW("onnx pooling mode must be \"max\" or \"average\"");
+            MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
         }
-        operation op = make_op(
-            "pooling",
-            {{"mode", mode == "average" ? op::pooling_mode::average : op::pooling_mode::max}});
+        operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
         value values = op.to_value();
         auto l0      = args[0];
         auto in_lens = l0->get_shape().lens();
@@ -74,6 +78,12 @@ struct parse_pooling : op_parser<parse_pooling>
                 kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
         }

+        // lp_order attribute
+        if(contains(info.attributes, "p"))
+        {
+            values["lp_order"] = info.attributes.at("p").i();
+        }
+
         // ensure pads availabe only when auto_pad is "NOT_SET"
         check_padding_mode(info, "POOLING");
@@ -118,7 +128,7 @@ struct parse_pooling : op_parser<parse_pooling>
             std::fill_n(values["stride"].begin(), kdims, 1);
         }
         // used to calculate the supposed output shape
-        std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());
+        std::vector<int64_t> orig_padding = paddings;
         std::vector<int64_t> slice_start;
         std::vector<int64_t> slice_end;
...
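Net effect: an ONNX LpPool node with attribute p = 2 should now come out of this parser roughly as the operator below; a hedged sketch, assuming lp_order can be passed straight to make_op the same way the hunk stores it into values["lp_order"]:

auto lp_pool = migraphx::make_op("pooling",
                                 {{"mode", migraphx::op::pooling_mode::lpnorm},
                                  {"lp_order", 2}});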
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

//! Parser for the ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for each batch along the batch axis.
Sequence lengths can be given to reverse only up to the given length for each
batch, keeping the rest of the sequence in its original order. Variable
sequence_lens is not supported in this version of MIGraphX; sequence_lens may
be passed either as a constant node or as an attribute. The batch axis and
time axis must each be 0 or 1 and must differ.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
    std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        int batch_axis = 1;
        if(contains(info.attributes, "batch_axis"))
        {
            batch_axis = info.attributes.at("batch_axis").i();
        }
        if(batch_axis != 0 and batch_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
        }

        int time_axis = 0;
        if(contains(info.attributes, "time_axis"))
        {
            time_axis = info.attributes.at("time_axis").i();
        }
        if(time_axis != 0 and time_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
        }
        if(time_axis == batch_axis)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
        }

        auto input      = args[0];
        auto input_lens = input->get_shape().lens();
        if(input_lens.size() < 2)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
        }

        std::vector<int64_t> sequence_lens;
        if(args.size() == 2)
        {
            migraphx::argument seq_lens_arg = args.back()->eval();
            check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
            seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "sequence_lens"))
        {
            literal s = parser.parse_value(info.attributes.at("sequence_lens"));
            s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); });
        }

        auto batch_size = input_lens[batch_axis];
        auto time_size  = input_lens[time_axis];
        // sequence_lens must supply exactly one length per batch
        if(sequence_lens.size() != batch_size)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
        }

        instruction_ref ret;
        auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) {
            return info.add_instruction(make_op("slice",
                                                {{"axes", {batch_axis, time_axis}},
                                                 {"starts", {b, t_start}},
                                                 {"ends", {b + 1, t_end}}}),
                                        input);
        };

        for(int b = 0; b < batch_size; ++b)
        {
            instruction_ref s0;
            if(sequence_lens[b] > 1)
            {
                s0 = add_slice(b, 0, sequence_lens[b]);
                s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0);
                // if less than the whole sequence was reversed, concat the rest of it
                if(sequence_lens[b] < time_size)
                {
                    auto s1 = add_slice(b, sequence_lens[b], time_size);
                    s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1);
                }
            }
            else
            { // cases where nothing changes
                s0 = add_slice(b, 0, time_size);
            }
            if(b == 0)
            {
                ret = s0;
            }
            else
            {
                ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0);
            }
        }
        return ret;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
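To make the slice/reverse/concat decomposition concrete, here is a small hand-worked case following ONNX ReverseSequence semantics (the values are hypothetical):

// time_axis = 0, batch_axis = 1, input shape {3, 2} (time x batch),
// sequence_lens = {3, 2}:
//
//   [[t0b0, t0b1],
//    [t1b0, t1b1],
//    [t2b0, t2b1]]
//
// batch 0: all 3 steps reversed              -> t2b0, t1b0, t0b0
// batch 1: first 2 steps reversed, tail kept -> t1b1, t0b1, t2b1
//
// The loop builds each batch column as slice -> reverse -> concat along the
// time axis, then concatenates the columns along the batch axis.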
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_scatter : op_parser<parse_scatter>
{
    std::vector<op_desc> operators() const { return {{"ScatterElements"}, {"Scatter"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        operation op;
        std::string op_name = "scatter_none";
        int axis            = 0;

        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        if(contains(info.attributes, "reduction"))
        {
            std::string reduction_att(info.attributes.at("reduction").s());
            // check for a valid reduction attribute. We have an operator for each one.
            if(not contains({"none", "add", "mul"}, reduction_att))
                MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction_att);
            // merge scatter with the reduction attribute to select the scatter operation.
            // Future reduction op names should follow this pattern and should also be
            // added to the check above.
            op_name = std::string("scatter_") + reduction_att;
        }
        op = migraphx::make_op(op_name, {{"axis", axis}});
        return info.add_instruction(op, args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
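For reference, the reduction attribute simply selects among three operator names; a sketch of the mapping this parse logic produces:

// ScatterElements / Scatter, no reduction -> scatter_none
// ScatterElements, reduction="add"        -> scatter_add
// ScatterElements, reduction="mul"        -> scatter_mul
auto sc = migraphx::make_op("scatter_add", {{"axis", 0}});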
@@ -30,11 +30,11 @@ struct parse_squeeze : op_parser<parse_squeeze>
                           std::vector<instruction_ref> args) const
     {
         auto op = parser.load(opd.op_name, info);
+        std::vector<int64_t> axes;
         if(args.size() == 2)
         {
             auto arg_axes = args.at(1)->eval();
             check_arg_empty(arg_axes, "PARSE_" + opd.op_name + ": cannot handle variable axes!");
-            std::vector<int64_t> axes;
             arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); });
             op = assign_axes(op, axes);
         }
...
@@ -15,7 +15,7 @@ std::ostream& operator<<(std::ostream& os, pooling_mode v)
 {
     // the strings for the enum are the same as the values used for onnx parsing
     // but this enum is not onnx-specific: strings must be converted when parsing tf
-    static const std::vector<std::string> pooling_mode_str = {"average", "max"};
+    static const std::vector<std::string> pooling_mode_str = {"average", "max", "lpnorm"};
     os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
     return os;
 }
...
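One subtlety worth noting: operator<< indexes the string table by the enum's underlying integer, so the table order must track the enum's declaration order; a hedged sketch of the assumed declaration:

enum class pooling_mode { average, max, lpnorm };
// pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)]
// then yields "average", "max", or "lpnorm" respectively.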
@@ -4,11 +4,11 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void memory_coloring::apply(module& p) const
+void memory_coloring::apply(module& m) const
 {
     if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
     {
-        memory_coloring_impl opt(&p, allocation_op, verify);
+        memory_coloring_impl opt(&m, allocation_op, verify);
         opt.run();
     }
 }
...
@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
     int ec = 0;
     if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
         std::cout << cmd << std::endl;
-    std::array<char, 128> buffer;
     auto closer = [&](FILE* stream) {
         auto status = pclose(stream);
         ec          = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
         std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
         if(!pipe)
             MIGRAPHX_THROW("popen() failed: " + cmd);
+        std::array<char, 128> buffer;
         while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
             std_out(buffer.data());
     }
...
@@ -3,6 +3,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/literal.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/par_for.hpp>
 #include <unordered_set>

 namespace migraphx {
@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
     return false;
 }

-void propagate_constant::apply(module& p) const
+bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
+
+void propagate_constant::apply(module& m) const
 {
-    for(auto i : iterator_for(p))
-    {
-        if(i->name() != "@literal")
-            continue;
-        if(i->outputs().empty())
-            continue;
-        fix([&](auto self, auto ins) {
-            std::unordered_set<instruction_ref> children(ins->outputs().begin(),
-                                                         ins->outputs().end());
-            for(auto child : children)
-            {
-                if(child->name() == "@literal" or skip_propogate(child))
-                {
-                    self(child);
-                    continue;
-                }
-                auto r = child->eval();
-                if(not r.empty())
-                {
-                    assert(r.get_shape() == child->get_shape());
-                    auto l = p.add_literal(r.get_shape(), r.data());
-                    self(p.replace_instruction(child, l));
-                }
-            }
-        })(i);
-    }
+    std::unordered_set<instruction_ref> const_instrs;
+    auto last = std::prev(m.end());
+
+    // Find instructions that can be evaluated to a literal
+    for(auto i : iterator_for(m))
+    {
+        if(is_const(i) and i != last)
+            continue;
+        std::copy_if(
+            i->inputs().begin(),
+            i->inputs().end(),
+            std::inserter(const_instrs, const_instrs.begin()),
+            [&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
+    }
+
+    // Compute literals in parallel
+    std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
+    std::vector<argument> literals(const_instrs_vec.size());
+    par_for(const_instrs_vec.size(), 1, [&](const auto i) {
+        literals[i] = const_instrs_vec[i]->eval();
+    });
+
+    // Replace instructions in m
+    for(size_t i = 0; i < const_instrs_vec.size(); i++)
+    {
+        if(not literals[i].empty())
+        {
+            assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
+            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
+            m.replace_instruction(const_instrs_vec[i], l);
+        }
+    }
 }
...
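The rewrite replaces the recursive one-at-a-time propagation with three phases: gather serially, evaluate in parallel, replace serially. Since each evaluation writes only its own output slot, the middle phase is a plain parallel map. A toy, self-contained illustration of that phase; toy_par_for is a hypothetical stand-in for migraphx::par_for, whose (count, grainsize, callable) shape appears in the hunk above:

#include <algorithm>
#include <cstdio>
#include <thread>
#include <vector>

// Toy stand-in for migraphx::par_for(n, grainsize, f): run f(i) for every
// i in [0, n) across hardware threads; grain-size handling is omitted.
template <class F>
void toy_par_for(std::size_t n, F f)
{
    std::size_t nthreads = std::max<std::size_t>(1, std::thread::hardware_concurrency());
    std::vector<std::thread> threads;
    for(std::size_t t = 0; t < nthreads; t++)
        threads.emplace_back([=] {
            for(std::size_t i = t; i < n; i += nthreads)
                f(i);
        });
    for(auto& th : threads)
        th.join();
}

int main()
{
    // The parallel phase in miniature: independent evaluations, each writing
    // to a disjoint result slot, so no locking is needed.
    std::vector<int> work{1, 2, 3, 4};
    std::vector<int> results(work.size());
    toy_par_for(work.size(), [&](std::size_t i) { results[i] = 2 * work[i]; });
    for(auto r : results)
        std::printf("%d\n", r);
    return 0;
}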