Commit 7dc6e3ae authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct topk
{
int64_t k = 1;
int64_t axis = 0;
bool largest = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest"));
}
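// The normalize_axes attribute below (an inference from normalize_attribute
// usage elsewhere in MIGraphX, not stated in this diff) lets a negative
// `axis` be remapped to the equivalent non-negative index before
// normalize_compute_shape is called.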
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "topk"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.at(0).lens();
auto type = inputs.at(0).type();
lens[axis] = k;
shape s_val{type, lens};
shape s_ind{shape::int64_type, lens};
return shape({s_val, s_ind});
}
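// Bounded top-k selection: keep a heap of the k best candidates seen so far,
// ordered so the worst of them sits at the front. try_push replaces that
// front element only when a better candidate arrives, giving O(n log k) work
// per slice instead of a full O(n log n) sort.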
template <class T, class Compare>
struct heap_vector
{
std::vector<T> data;
Compare compare;
heap_vector(const std::vector<T>& val, Compare comp) : data(val), compare(std::move(comp))
{
std::make_heap(data.begin(), data.end(), compare);
}
void try_push(T val)
{
if(not compare(val, data.front()))
return;
std::pop_heap(data.begin(), data.end(), compare);
data.back() = val;
std::push_heap(data.begin(), data.end(), compare);
}
std::vector<T> sort()
{
auto sorted_data = data;
std::sort_heap(sorted_data.begin(), sorted_data.end(), compare);
return sorted_data;
}
};
template <class T, class Compare>
heap_vector<T, Compare> make_heap(std::vector<T> val, Compare compare) const
{
return {std::move(val), std::move(compare)};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto vec_ss = output_shape.sub_shapes();
argument res_val{vec_ss.front()};
argument res_ind{vec_ss.back()};
auto in_s = args.front().get_shape();
auto out_s = vec_ss.front();
auto comp_lens = in_s.lens();
auto axis_dim = comp_lens[axis];
// iteration shape: the top-k axis collapsed to 1, so each element indexes one slice along that axis
comp_lens[axis] = 1;
shape comp_s{in_s.type(), comp_lens};
visit_all(res_val, args.front())([&](auto out_val, auto input) {
auto* out_ind = res_ind.cast<int64_t>();
par_for(comp_s.elements(), [&](auto i) {
auto idx = comp_s.multi(i);
std::vector<std::size_t> indices(k);
std::iota(indices.begin(), indices.end(), 0);
auto comp = [&](auto i1, auto i2) {
auto idx1 = idx;
auto idx2 = idx;
idx1[axis] = i1;
idx2[axis] = i2;
return this->largest
? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)])
: std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]);
};
auto hp = this->make_heap(indices, comp);
for(std::size_t ii = indices.size(); ii < axis_dim; ++ii)
{
hp.try_push(ii);
}
auto sorted_indices = hp.sort();
auto out_idx = idx;
auto in_idx = idx;
for(auto j : range(sorted_indices.size()))
{
out_idx[axis] = j;
in_idx[axis] = sorted_indices[j];
out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)];
out_ind[out_s.index(out_idx)] = sorted_indices[j];
}
});
});
return argument({res_val, res_ind});
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -21,7 +21,7 @@ struct transpose
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
return pack(f(self.dims, "permutation"));
}
std::string name() const { return "transpose"; }
@@ -32,31 +32,23 @@ struct transpose
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
auto tuned_dims = dims;
// if no permutation is provided, reverse the dims
if(tuned_dims.empty())
{
tuned_dims.resize(input_lens.size());
std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
std::reverse(tuned_dims.begin(), tuned_dims.end());
}
if(tuned_dims.size() != input_lens.size())
if(dims.size() != input_lens.size())
{
MIGRAPHX_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(tuned_dims.size());
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), tuned_dims.begin()))
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPHX_THROW("Invalid permutation");
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
for(std::size_t i = 0; i < output_lens.size(); i++)
{
output_lens[i] = input_lens[tuned_dims[i]];
output_strides[i] = input_strides[tuned_dims[i]];
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {t, output_lens, output_strides};
}
...
@@ -41,7 +41,11 @@ struct unary : op_name<Derived>
{
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0);
if(s.broadcasted())
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
...
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct where
{
std::string name() const { return "where"; }
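// pointwise/point_op mark this op as elementwise; the ${N} placeholders are
// presumably substituted with the operand expressions by the pointwise
// fusion/codegen machinery.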
value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_dims();
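// Prefer a packed (contiguous) shape for the result when either branch
// input provides one; otherwise fall back to a standard shape with the
// common lengths.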
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -256,6 +256,18 @@ argument compute_op(const T& x,
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<4>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}
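// rank<4> is tried first by the dispatcher below; if T has no compute
// overload taking a context and module arguments, template substitution
// fails and overload resolution falls through to the lower-rank versions.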
template <class T, class F>
auto compute_op(rank<3>,
const T& x,
@@ -313,7 +325,7 @@ argument compute_op(const T& x,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
...
@@ -49,6 +49,7 @@
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
@@ -95,11 +96,13 @@
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
@@ -31,13 +34,37 @@ struct pass
#else
module& get_module(module_pass_manager& mpm);
namespace detail {
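// Tag dispatch via rank<>: when T provides apply(module&), the rank<1>
// overload is selected through SFINAE and the pass runs on the module
// extracted from the pass manager; otherwise the rank<0> overload turns the
// call into a no-op.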
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(module & m) const;
* void apply(module_pass_manager & mpm) const;
* void apply(program & p) const;
* };
*
@@ -112,10 +139,10 @@ struct pass
return (*this).private_detail_te_get_handle().name();
}
void apply(module& m) const
void apply(module_pass_manager& mpm) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(m);
(*this).private_detail_te_get_handle().apply(mpm);
}
void apply(program& p) const
@@ -137,22 +164,24 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(module& m) const = 0;
virtual void apply(program& p) const = 0;
virtual std::string name() const = 0;
virtual void apply(module_pass_manager& mpm) const = 0;
virtual void apply(program& p) const = 0;
};
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, module& m)
-> decltype(private_detail_te_self.apply(m))
static auto
private_detail_te_default_apply(char, T&& private_detail_te_self, module_pass_manager& mpm)
-> decltype(private_detail_te_self.apply(mpm))
{
private_detail_te_self.apply(m);
private_detail_te_self.apply(mpm);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, module& m)
static void
private_detail_te_default_apply(float, T&& private_detail_te_self, module_pass_manager& mpm)
{
migraphx::nop(private_detail_te_self, m);
migraphx::detail::module_pass_manager_apply(private_detail_te_self, mpm);
}
template <class T>
@@ -198,10 +227,10 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); }
void apply(module& m) const override
void apply(module_pass_manager& mpm) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, m);
private_detail_te_default_apply(char(0), private_detail_te_value, mpm);
}
void apply(program& p) const override
...
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
#include <migraphx/pass.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager
{
module_pass_manager() = default;
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual void run_pass(const pass& p) = 0;
protected:
virtual ~module_pass_manager() {}
};
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
...
@@ -17,32 +17,10 @@ struct program;
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{},
"Dangling reference to target!");
return capture_arguments_impl(prog, t, ins_names);
}
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* quantize a program to fp16
*/
struct quantize_fp16_pass
{
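// instruction names to convert; the default "all" selects every supported
// instruction (an inference from the default value, not stated in this diff)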
std::vector<std::string> ins_names = {"all"};
std::string name() const { return "quantize_fp16"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* capture inputs of operators to be quantized to int8
*/
struct capture_arguments_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; }
void apply(module& m) const;
};
/**
* quantize a program to int8
*/
struct quantize_int8_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_int8"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -169,6 +169,12 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it);
}
template <class Range, class Iterator, class F>
void transform(Range&& r, Iterator it, F f)
{
std::transform(r.begin(), r.end(), it, f);
}
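// Minimal usage sketch (hypothetical values, not part of this diff):
//   std::vector<int> src = {1, 2, 3};
//   std::vector<int> dst(src.size());
//   migraphx::transform(src, dst.begin(), [](int x) { return x * 2; });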
template <class Range>
auto reverse(Range& r)
{
...
#ifndef MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#define MIGRAPHX_GUARD_RTGLIB_RUN_LOOP_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/module.hpp>
#include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class LoopModel, class T>
argument run_loop(const LoopModel& model,
T& ctx,
std::vector<argument> args,
const std::vector<module_ref>& mods,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run)
{
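// Assumed calling convention (inferred from the unpacking below, matching
// the ONNX Loop operator): args = {max_trip_count, cond, per-iteration
// inputs..., loop-carried dependencies..., output buffers}, and each body
// run yields {cond, carried deps..., scan outputs...}.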
std::vector<std::vector<argument>> results;
// unpack the argument lists
auto iter_num = args.at(0).at<int64_t>();
auto cond = args.at(1).at<bool>();
auto input_num = (args.size() - 2) / 2;
auto dep_num = input_num - 2;
module_ref mod = mods.at(0);
auto param_name_shapes = mod->get_parameter_shapes();
auto param_names = mod->get_parameter_names();
std::vector<argument> dep0(args.begin() + input_num + 1, args.begin() + 2 * input_num);
std::vector<argument> dep1(args.begin() + 2 * input_num, args.begin() + 2 * input_num + 1);
auto ins_outputs = args.back().get_sub_objects();
dep1.insert(dep1.end(), ins_outputs.begin(), ins_outputs.begin() + dep_num);
std::array<std::vector<argument>, 2> loop_carry_deps = {dep0, dep1};
// per-iteration inputs: slot 0 holds the iteration number, slot 1 the condition
std::vector<argument> in_args = {args.at(input_num), dep1.at(0)};
in_args.insert(in_args.end(), args.begin() + 2, args.begin() + input_num);
std::vector<argument> out_args = dep0;
out_args.insert(out_args.end(), ins_outputs.begin() + dep_num, ins_outputs.end());
std::vector<argument> scan_outputs(ins_outputs.begin() + dep_num, ins_outputs.end());
auto out_param_indices = model.get_output_params(*mod);
int64_t iter = 0;
for(iter = 0; iter < iter_num and cond; ++iter)
{
// copy iter num and cond to device memory
model.copy(ctx, iter, in_args.at(0));
model.copy(ctx, cond, in_args.at(1));
// bind module parameters to their input and output buffers
std::unordered_map<std::string, argument> params;
int input_index = 0;
for(const auto& name : param_names)
{
auto ps = mod->get_parameter_shape(name);
if(ps == shape{})
{
continue;
}
// it is an input parameter
if(not contains(out_param_indices, name))
{
params[name] = in_args.at(input_index++);
}
else
{
auto output_index = out_param_indices[name];
if(output_index > dep_num)
{
const auto& arg = out_args.at(output_index);
assert((iter + 1) * ps.bytes() <= arg.get_shape().bytes());
params[name] = argument(ps, arg.data() + iter * ps.bytes());
}
else
{
params[name] = out_args.at(output_index);
}
}
}
auto mod_args = run(mod, params);
// copy back cond to be used next iteration
model.copy(ctx, mod_args.at(0), cond);
// mod outputs are used as next loop input
std::copy(mod_args.begin(), mod_args.begin() + dep_num + 1, in_args.begin() + 1);
const auto& dep_out = loop_carry_deps[(iter + 1) % 2];
std::copy(dep_out.begin(), dep_out.end(), out_args.begin());
std::vector<argument> mod_scan_outs(mod_args.begin() + 1 + dep_num, mod_args.end());
model.append(mod_scan_outs, scan_outputs, iter);
}
out_args.erase(out_args.begin());
std::copy(in_args.begin() + 2, in_args.end(), out_args.begin());
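// zero out any scan-output entries beyond the last executed iteration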
model.set_zero(ctx, scan_outputs, iter);
return argument(out_args);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Inserts quantized operators in place of dq->quantizable_op->q
* then removes remaining fake quantization (q->dq pairs)
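 *
 * Schematically (illustrative, not from this diff):
 *   dequantizelinear -> dot -> quantizelinear  ==>  quant_dot on the int8
 * operands, after which any q->dq pair that merely round-trips a value is
 * removed.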
*/
struct simplify_qdq
{
std::string name() const { return "simplify_qdq"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -393,6 +393,31 @@ struct value
return result;
}
template <class To>
To get(const std::string& pkey, const To& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to<To>();
}
template <class To>
std::vector<To> get(const std::string& pkey, const std::vector<To>& default_value) const
{
const auto* v = find(pkey);
if(v == this->end())
return default_value;
return v->to_vector<To>();
}
template <class To>
std::vector<To> get(const std::string& pkey,
const std::initializer_list<To>& default_value) const
{
return get<std::vector<To>>(pkey, default_value);
}
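// Minimal usage sketch (hypothetical keys and values, not part of this diff):
//   value v = {{"axis", 1}};
//   auto axis = v.get("axis", int64_t{0}); // 1, since the key exists
//   auto dims = v.get("dims", {1, 1});     // {1, 1}, the default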
friend bool operator==(const value& x, const value& y);
friend bool operator!=(const value& x, const value& y);
friend bool operator<(const value& x, const value& y);
...
@@ -125,9 +125,10 @@ void module::assign(const module& m)
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto order = any_cast<builtin::param>(ins->get_operator()).order;
auto s = ins->get_shape();
copy_ins =
impl->insert(impl->instructions.end(), {builtin::param{name}, std::move(s), {}});
copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
@@ -334,7 +335,6 @@ shape module::get_parameter_shape(std::string name) const
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
@@ -430,7 +430,6 @@ instruction_ref module::validate() const
bool check_order = std::all_of(inputs.begin(), inputs.end(), [&](auto in) {
return contains(impl->instructions, *in);
});
return !i.valid(impl->instructions.begin(), check_order);
});
}
...
@@ -63,6 +63,7 @@ struct onnx_parser
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
...
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <iostream>
#include <fstream>
#include <unordered_map>
@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
if(options.print_program_on_error)
{
@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
return parse_onnx_from(options, data, size);
}
std::vector<std::string> get_onnx_operators() { return onnx::get_op_parsers(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -85,7 +85,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
@@ -224,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
instructions[f.name()] = mod->add_literal(parse_tensor(f));
// collect this graph's instructions in a local map; they are merged into the parser-wide map below
mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
if(!contains(mod_insts, name))
{
// The ONNX specification does not say how to handle a nested subgraph
// that declares a parameter whose name already exists in its parent
// graph. In that case, the current MIGraphX implementation throws an
// exception.
if(contains(instructions, name))
{
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!");
}
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
shape s = parse_type(input.type(), dims);
instructions[name] = mod->add_parameter(name, s);
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
}
}
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node())
{
std::vector<instruction_ref> args;
@@ -309,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// add the return instruction
mod->add_return(output_ins);
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
...