Unverified Commit a275f590 authored by Shucai Xiao, committed by GitHub

Loop operator (#853)



Add Loop operator for opset version 13.
Notes: 1) The default max iteration number is 10 if none is provided.
2) To change the max iteration number, a user can set max_loop_iterations in the onnx_options struct when parsing a model (a sketch follows below).
3) The returned shape of the scan output is determined by max_loop_iterations even when the actual loop count is smaller. This issue also applies to other operators such as NonZero and NonMaxSuppression. Issue #948 has been created to track this and will be resolved later.
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 8b4c69c5
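To make note 2 concrete, a minimal sketch of overriding the default through the parser-facing onnx_options struct. This is a sketch only: it assumes the migraphx::parse_onnx overload declared in <migraphx/onnx.hpp>, and the model path is a placeholder.

#include <migraphx/onnx.hpp>

int main()
{
    migraphx::onnx_options options;
    options.max_loop_iterations = 20; // override the default of 10
    auto prog = migraphx::parse_onnx("model.onnx", options); // placeholder path
    return 0;
}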
......@@ -122,6 +122,7 @@ register_migraphx_ops(
logical_or
logical_xor
logsoftmax
loop
lrn
lstm
max
......
File mode changed from 100755 to 100644
......@@ -93,6 +93,11 @@ void set_default_dim_value(onnx_options& options, size_t value)
options.default_dim_value = value;
}
void set_default_loop_iterations(onnx_options& options, int64_t value)
{
options.max_loop_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......@@ -843,6 +848,17 @@ migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options
});
}
extern "C" migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value)
{
return migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_default_loop_iterations((onnx_options->object), (value));
});
}
extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{
......
......@@ -243,6 +243,10 @@ migraphx_status migraphx_onnx_options_set_input_parameter_shape(
migraphx_status migraphx_onnx_options_set_default_dim_value(migraphx_onnx_options_t onnx_options,
size_t value);
migraphx_status
migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_options,
int64_t value);
migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
......
......@@ -630,6 +630,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
call(&migraphx_onnx_options_set_default_dim_value, this->get_handle_ptr(), value);
}
/// Set default max iteration number for the loop operator
void set_default_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
}
};
/// Parse an onnx file into a migraphx program
......
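The same knob is exposed through the handle-based C++ API shown above. A minimal sketch, assuming the migraphx.hpp wrapper's parse_onnx overload that accepts onnx_options:

#include <migraphx/migraphx.hpp>

int main()
{
    migraphx::onnx_options options;
    options.set_default_loop_iterations(20); // the parser default is 10
    migraphx::program p = migraphx::parse_onnx("model.onnx", options); // placeholder path
    return 0;
}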
......@@ -243,6 +243,11 @@ def onnx_options(h):
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_default_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)',
)
api.add_function('migraphx_parse_onnx',
......
......@@ -10,11 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {
"convert", "get_tuple_elem", "if", "loop"};
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
continue;
if(contains({"convert", "get_tuple_elem"}, ins->name()))
if(contains(skip_op_names, ins->name()))
continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
......
......@@ -18,6 +18,8 @@ struct onnx_options
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
/// Max iteration number for the loop operator
int64_t max_loop_iterations = 10;
};
/// Create a program from an onnx file
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOOP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <migraphx/run_loop.hpp>
#include <migraphx/ranges.hpp>
#include <cmath>
#include <string>
#include <utility>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct loop
{
int64_t max_iterations = 10;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_iterations, "max_iterations"));
}
std::string name() const { return "loop"; }
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
check_shapes{inputs, *this}.standard();
if(mods.size() != 1)
{
MIGRAPHX_THROW("LOOP: operator should have one submodule.");
}
const auto& mod = mods.front();
auto mod_out_shapes = mod->get_output_shapes();
auto dep_param_num = inputs.size() - 2;
// the first item of the module output shapes is the loop condition,
// which is not needed to compute the output shape
mod_out_shapes.erase(mod_out_shapes.begin());
std::vector<shape> ins_out_shapes(mod_out_shapes.begin(),
mod_out_shapes.begin() + dep_param_num);
mod_out_shapes.erase(mod_out_shapes.begin(), mod_out_shapes.begin() + dep_param_num);
for(const auto& out_s : mod_out_shapes)
{
auto lens = out_s.lens();
lens.insert(lens.begin(), max_iterations);
ins_out_shapes.push_back({out_s.type(), lens});
}
return shape(ins_out_shapes);
}
struct ref_loop
{
int64_t max_iterations = 0;
template <class T>
void copy(context&, const argument& src, T& dst) const
{
dst = *src.cast<T>();
}
template <class T>
void copy(context&, T src, const argument& dst) const
{
*dst.cast<T>() = src;
}
void append(const std::vector<argument>& iter_state,
const std::vector<argument>& concatenated_outputs,
int iter) const
{
assert(iter_state.size() == concatenated_outputs.size());
for(auto i : range(iter_state.size()))
{
const auto& iter_stat = iter_state.at(i);
const auto& scan_out = concatenated_outputs.at(i);
auto* in_data = iter_stat.data();
auto* out_data = scan_out.data();
std::size_t out_size = iter_stat.get_shape().bytes();
assert((iter + 1) * out_size <= scan_out.get_shape().bytes());
std::copy(in_data, in_data + out_size, out_data + iter * out_size);
}
}
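// zero-fill the scan-output slots for iterations that never ran; the scan
// outputs are always allocated for max_iterations (see issue #948 in the notes)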
void set_zero(context&, const std::vector<argument>& concatenated_outputs, int iter) const
{
if(iter >= max_iterations)
return;
for(const auto& out : concatenated_outputs)
{
auto s = out.get_shape();
auto size = s.bytes() / max_iterations;
std::fill(out.data() + iter * size, out.data() + max_iterations * size, 0);
}
}
std::unordered_map<std::string, int> get_output_params(const module&) const { return {}; }
};
argument compute(context& ctx,
const shape& out_shape,
const std::vector<argument>& args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
// wrap up the arguments vector so the ref and gpu implementations can share the same code path
auto cpy_args = args;
bool in_cond = args.at(1).at<bool>();
bool cond = in_cond;
int64_t iter = 0;
// insert iter and cond used in the loop
auto s_cond = args.at(1).get_shape();
auto s_iter = args.at(0).get_shape();
cpy_args.push_back({s_iter, &iter});
cpy_args.push_back({s_cond, &cond});
cpy_args.insert(cpy_args.end(), args.begin() + 2, args.end());
// add cond and mod outputs to the argument list
cpy_args.push_back(argument(s_cond));
cpy_args.push_back(argument(out_shape));
// run loop
return run_loop(ref_loop{max_iterations}, ctx, cpy_args, mods, run);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
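To illustrate the shape rule in loop::compute_shape above (and note 3 of the commit message): loop-carried dependencies keep their shapes, while each scan output gets max_iterations prepended regardless of the actual trip count. A minimal sketch, assuming the migraphx::shape tuple constructor that takes a vector of sub-shapes:

#include <migraphx/shape.hpp>
#include <cassert>

int main()
{
    using migraphx::shape;
    shape carried{shape::float_type, {2, 2}}; // loop-carried dependency keeps its shape
    shape scan{shape::float_type, {10, 3}};   // per-iteration shape {3}, padded to max_iterations = 10
    shape out({carried, scan});               // tuple shape as returned by the operator
    assert(out.type() == shape::tuple_type);
    return 0;
}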
......@@ -256,6 +256,18 @@ argument compute_op(const T& x,
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<4>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<3>,
const T& x,
......@@ -313,7 +325,7 @@ argument compute_op(const T& x,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......
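The rank<N> arguments above implement tag dispatch: the call site starts at the highest rank (now rank<4>, so the new context-taking module overload is tried first), and overload resolution falls through to lower ranks whenever the expression in the trailing return type fails to compile. A self-contained sketch of the idiom, with simplified, hypothetical names rather than the MIGraphX definitions:

#include <iostream>

template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

// preferred overload: viable only when T has a compute() member
template <class T>
auto dispatch(rank<1>, const T& x) -> decltype(x.compute())
{
    return x.compute();
}

// fallback overload, selected when the rank<1> overload is SFINAEd away
template <class T>
int dispatch(rank<0>, const T&)
{
    return -1;
}

struct with_compute
{
    int compute() const { return 42; }
};
struct without_compute
{
};

int main()
{
    std::cout << dispatch(rank<1>{}, with_compute{}) << "\n";    // prints 42
    std::cout << dispatch(rank<1>{}, without_compute{}) << "\n"; // prints -1
    return 0;
}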
......@@ -49,6 +49,7 @@
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
......
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
T& ctx,
std::vector<argument> args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run)
{
std::vector<std::vector<argument>> results;
// process the argument lists
auto iter_num = args.at(0).at<int64_t>();
auto cond = args.at(1).at<bool>();
auto input_num = (args.size() - 2) / 2;
auto dep_num = input_num - 2;
module_ref mod = mods.at(0);
auto param_name_shapes = mod->get_parameter_shapes();
auto param_names = mod->get_parameter_names();
std::vector<argument> dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num);
std::vector<argument> dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1);
auto ins_outputs = args.back().get_sub_objects();
dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num);
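// double-buffer the loop-carried dependencies; the two buffers swap roles
// as input/output on alternating iterations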
std::array<std::vector<argument>, 2> loop_carry_deps = {dep0, dep1};
// loop iter argument
std::vector<argument> in_args = {args.at(input_num), dep1.at(0)};
in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num);
std::vector<argument> out_args = dep0;
out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end());
std::vector<argument> scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end());
auto out_param_indices = model.get_output_params(*mod);
int64_t iter = 0;
for(iter = 0; iter < iter_num and cond; ++iter)
{
// copy iter num and cond to device memory
model.copy(ctx, iter, in_args.at(0));
model.copy(ctx, cond, in_args.at(1));
// wrap up the inputs and outputs
std::unordered_map<std::string, argument> params;
int input_index = 0;
for(const auto& name : param_names)
{
auto ps = mod->get_parameter_shape(name);
if(ps == shape{})
{
continue;
}
// it is an input parameter
if(not contains(out_param_indices, name))
{
params[name] = in_args.at(input_index++);
}
else
{
auto output_index = out_param_indices[name];
if(output_index > dep_num)
{
const auto& arg = out_args.at(output_index);
assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + iter * ps.bytes());
}
else
{
params[name] = out_args.at(output_index);
}
}
}
auto mod_args = run(mod, params);
// copy back cond to be used in the next iteration
model.copy(ctx, mod_args.at(0), cond);
// mod outputs are used as next loop input
std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1);
const auto& dep_out = loop_carry_deps[(iter + 1) % 2];
std::copy(dep_out.begin(), dep_out.end(), out_args.begin());
std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
model.append(mod_scan_outs, scan_outputs, iter);
}
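// after the loop: drop the condition slot and copy the final values of the
// loop-carried dependencies to the front of the output list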
out_args.erase(out_args.begin());
std::copy(in_args.begin() + 2, in_args.end(), out_args.begin());
model.set_zero(ctx, scan_outputs, iter);
return argument(out_args);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -335,7 +335,6 @@ shape module::get_parameter_shape(std::string name) const
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
......@@ -431,7 +430,6 @@ instruction_ref module::validate() const
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) {
return contains(impl->instructions, *in);
});
return !i.valid(impl->instructions.begin(), check_order);
});
}
......
......@@ -63,6 +63,7 @@ struct onnx_parser
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
......
......@@ -21,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
if(options.print_program_on_error)
{
......
......@@ -224,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
instructions[f.name()] = mod->add_literal(parse_tensor(f));
// back up instructions in the parent mod
mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
if(!contains(mod_insts, name))
{
// The ONNX specification does not specify how to handle the
// scenario in which a nested subgraph contains a parameter with the
// same name as one in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
if(contains(instructions, name))
{
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!");
}
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
shape s = parse_type(input.type(), dims);
instructions[name] = mod->add_parameter(name, s);
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
}
}
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node())
{
std::vector<instruction_ref> args;
......@@ -309,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// add the return instruction
mod->add_return(output_ins);
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_loop : op_parser<parse_loop>
{
std::vector<op_desc> operators() const { return {{"Loop"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// default value of the max iteration number
int64_t max_iterations = parser.max_loop_iterations;
// iteration input is empty
if(args.at(0)->name() == "undefined")
{
shape iter_s{shape::int64_type};
args[0] = info.add_literal(literal(iter_s, {max_iterations}));
}
else
{
auto arg_iters = args.at(0)->eval();
if(not arg_iters.empty())
{
max_iterations = arg_iters.at<int64_t>();
}
}
// condition input is empty
if(args.at(1)->name() == "undefined")
{
shape cond_s{shape::bool_type};
args[1] = info.add_literal(literal(cond_s, {true}));
}
// retrieve the subgraph
const auto& sub_graph = info.attributes.at("body").g();
std::string mod_name = info.name + "_loop";
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
auto out_s = ret->get_shape();
assert(out_s.type() == shape::tuple_type);
const auto& vec_shapes = out_s.sub_shapes();
std::vector<instruction_ref> out_inss;
for(std::size_t i = 0; i < vec_shapes.size(); ++i)
{
auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret);
out_inss.push_back(r);
}
return out_inss;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -242,7 +242,8 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
return generic_eval(smod, ctx, inputs, results, make_trace);
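// give the submodule its own copy of the context for evaluation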
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
};
results.emplace(ins, trace(ins, [&] {
......
......@@ -325,12 +325,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
......@@ -338,7 +340,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
......
......@@ -41,6 +41,7 @@ add_library(migraphx_device
device/equal.cpp
device/erf.cpp
device/exp.cpp
device/fill.cpp
device/floor.cpp
device/gather.cpp
device/gelu.cpp
......@@ -138,6 +139,7 @@ add_library(migraphx_gpu
kernel.cpp
lowering.cpp
logsoftmax.cpp
loop.cpp
lrn.cpp
leaky_relu.cpp
mlir_conv.cpp
......@@ -193,6 +195,7 @@ register_migraphx_gpu_ops(hip_
logical_and
logical_or
logical_xor
loop
max
min
mul
......