Commit 421ecad6 authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents 3d0426e9 7cf05301
@@ -45,57 +45,48 @@ def shape_type_wrap(p):
p.read = 'migraphx::to_shape_type(${name})'
@api.cwrap('migraphx::compile_options')
def compile_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_compile_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_compile_options(${result})']
else:
p.add_param('migraphx_compile_options *')
p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
@api.cwrap('migraphx::file_options')
def file_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_file_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_file_options(${result})']
else:
p.add_param('migraphx_file_options *')
p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
def auto_handle(*args, **kwargs):
    def with_handle(f):
        return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
                          *args, **kwargs)(f)
    return with_handle
@api.cwrap('migraphx::onnx_options')
def onnx_options_type_wrap(p):
    if p.returns:
        p.add_param('migraphx_onnx_options *')
        p.bad_param('${name} == nullptr', 'Null pointer')
        p.write = ['*${name} = migraphx::to_onnx_options(${result})']
    else:
        p.add_param('migraphx_onnx_options *')
        p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
@api.cwrap('migraphx::tf_options')
def tf_options_type_wrap(p):
if p.returns:
p.add_param('migraphx_tf_options *')
p.bad_param('${name} == nullptr', 'Null pointer')
p.write = ['*${name} = migraphx::to_tf_options(${result})']
else:
p.add_param('migraphx_tf_options *')
p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
@api.handle('migraphx_optimals', 'std::set<size_t>')
def optimals(h):
h.constructor('create',
api.params(ptr='const size_t*', size='size_t'),
fname='migraphx::make_set<size_t>')
def auto_handle(*args, **kwargs):
    def with_handle(f):
        return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
                          *args, **kwargs)(f)
    return with_handle
@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
def dynamic_dimension(h):
    h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
    h.constructor(
        'create_min_max_optimals',
        api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
    h.method('is_fixed', returns='bool', const=True)
    h.method('equal',
             api.params(x='const migraphx::shape::dynamic_dimension&'),
             invoke='migraphx::equal($@)',
             returns='bool',
             const=True)
@api.handle('migraphx_dynamic_dimensions',
'std::vector<migraphx::shape::dynamic_dimension>')
def dynamic_dimensions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
h.method('size', returns='size_t')
h.method('get',
api.params(idx='size_t'),
fname='at',
cpp_name='operator[]',
returns='const migraphx::shape::dynamic_dimension&')
@auto_handle()
@@ -110,20 +101,29 @@ def shape(h):
lengths='std::vector<size_t>',
strides='std::vector<size_t>'))
h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
h.constructor(
'create_dynamic',
api.params(type='migraphx::shape::type_t',
dims='std::vector<migraphx::shape::dynamic_dimension>'))
h.method('lengths',
fname='lens',
returns='const std::vector<size_t>&',
const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('dyn_dims',
returns='std::vector<migraphx::shape::dynamic_dimension>',
const=True)
h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True)
h.method('ndim', returns='size_t', const=True)
h.method('equal',
api.params(x='const migraphx::shape&'),
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('standard', returns='bool', const=True)
h.method('dynamic', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@@ -131,6 +131,7 @@ def shape(h):
def argument(h):
h.constructor('create',
api.params(shape='const migraphx::shape&', buffer='void*'))
h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
h.method('shape',
fname='get_shape',
cpp_name='get_shape',
@@ -326,11 +327,22 @@ def onnx_options(h):
api.params(name='const char*', dims='std::vector<size_t>'),
invoke='migraphx::set_input_parameter_shape($@)',
)
h.method(
'set_dyn_input_parameter_shape',
api.params(name='const char*',
dims='std::vector<migraphx::shape::dynamic_dimension>'),
invoke='migraphx::set_dyn_input_parameter_shape($@)',
)
h.method(
'set_default_dim_value',
api.params(value='size_t'),
invoke='migraphx::set_default_dim_value($@)',
)
h.method(
'set_default_dyn_dim_value',
api.params(dd='const migraphx::shape::dynamic_dimension&'),
invoke='migraphx::set_default_dyn_dim_value($@)',
)
h.method(
'set_default_loop_iterations',
api.params(value='int64_t'),
......
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this;
}
cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
{
body.insert(0, "(void)" + pname + ";\n");
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{
params.push_back({pname, "T" + pname});
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
else if(with_char(::isdigit)(key[0]))
{
auto i = std::stoul(key);
if(i >= args.size())
MIGRAPHX_THROW("Invalid argument index: " + key);
return args.at(i);
}
else if(v.contains(key))
@@ -238,6 +245,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
if(f.params.empty())
impl->fs << delim;
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
......
@@ -436,11 +436,6 @@ struct compiler
{"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(co.split_single_dyn_dim,
{"--split-single-dyn-dim"},
ap.help("If there is a single non-fixed dynamic dimension in the model, then split to "
"static submodules"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
}
@@ -662,6 +657,26 @@ struct onnx : command<onnx>
}
};
struct tf : command<tf>
{
bool show_ops = false;
void parse(argument_parser& ap)
{
ap(show_ops,
{"--list", "-l"},
ap.help("List all tf operators supported by MIGraphX"),
ap.set_value(true));
}
void run() const
{
if(show_ops)
{
for(const auto& name : get_tf_operators())
std::cout << name << std::endl;
}
}
};
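With this subcommand in place, the supported TensorFlow operators can be listed from the driver. Assuming the usual migraphx-driver binary name, the invocation would be:

migraphx-driver tf --list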
struct main_command
{
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
......
@@ -31,6 +31,8 @@
#include <migraphx/ranges.hpp>
#include <iterator>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
continue;
if(ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
}
}
// Don't create a pointwise module if no inputs are detected
if(pointwise_inputs.empty())
continue;
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
{
create_pointwise_modules(mpm);
mpm.run_pass(dead_code_elimination{});
if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
{
return;
}
for(int i = 0; i < 8; i++)
{
if(not find_pointwise_modules(mpm.get_module()))
......
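Note the ordering above: create_pointwise_modules and dead_code_elimination still run before the early return, so setting the MIGRAPHX_DISABLE_POINTWISE_FUSION environment variable disables only the repeated find_pointwise_modules fusion loop, not the creation of the pointwise submodules themselves.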
@@ -40,9 +40,6 @@ struct compile_options
bool fast_math = true;
bool exhaustive_tune = false;
/// Use the split_single_dyn_dim pass
bool split_single_dyn_dim = false;
tracer trace{};
};
......
@@ -78,6 +78,7 @@ struct cpp_generator
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
function& add_generic_param(const std::string& pname);
function& unused_param(const std::string& pname);
};
cpp_generator();
......
@@ -37,6 +37,15 @@ namespace op {
struct dequantizelinear
{
value attributes() const
{
// Note: the point_op attribute is not used in this op. Instead, in the
// gpu compilation pipeline, rewrite_quantization is invoked from
// generate_pointwise() to rewrite this op.
return {{"pointwise", true}};
}
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
@@ -45,14 +45,15 @@ struct pointwise
{
MIGRAPHX_THROW("should have one submodule.");
}
auto* pm = mods.front();
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("pointwise should have only one output.");
if(inputs.empty())
MIGRAPHX_THROW("pointwise should have at least one input");
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("submodule should have only one output.");
auto type = pm->get_output_shapes().front().type();
// Scalar output if all inputs are scalar
......
@@ -38,6 +38,15 @@ namespace op {
struct quantizelinear
{
std::string name() const { return "quantizelinear"; }
value attributes() const
{
// Note: the point_op attribute is not used in this op. Instead, in the
// gpu compilation pipeline, rewrite_quantization is invoked from
// generate_pointwise() to rewrite this op.
return {{"pointwise", true}};
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims().has(2, 3);
......
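For context, a minimal pure-Python sketch of the elementwise math the two quantization ops above implement, following the ONNX QuantizeLinear/DequantizeLinear definitions (the uint8 saturation bounds here are an illustrative assumption):

def quantizelinear(x, scale, zero_point=0, qmin=0, qmax=255):
    # saturate(round(x / scale) + zero_point)
    return max(qmin, min(qmax, round(x / scale) + zero_point))

def dequantizelinear(y, scale, zero_point=0):
    # (y - zero_point) * scale
    return (y - zero_point) * scale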
@@ -125,7 +125,7 @@ struct select_module
auto ps = param_shapes.at(name);
if(a.get_shape() != ps)
{
assert(ps.bytes() == a.get_shape().bytes());
assert(ps.bytes() <= a.get_shape().bytes());
return std::make_pair(name, a.reshape(ps));
}
else
......
@@ -222,11 +222,15 @@ struct shape
/// Map element index to space index
std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Map element index to multi-dimensional index
std::vector<std::size_t> multi(std::size_t idx) const;
/// Map element index to multi-dimensional index and put it into the location provided by
/// the pointers
void multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with
/// no padding
bool packed() const;
/// Returns true if the shape has been transposed. That is, the strides are not in descending
......
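The semantics of multi can be mirrored in pure Python (an illustrative sketch that matches the divmod-based implementation added to shape.cpp later in this diff):

def multi(lens, idx):
    # peel off the fastest-varying (rightmost) dimension first
    out = [0] * len(lens)
    for i in range(len(lens) - 1, 0, -1):
        out[i] = idx % lens[i]
        idx //= lens[i]
    out[0] = idx
    return out

assert multi([2, 3, 4], 13) == [1, 0, 1]  # 13 == 1*12 + 0*4 + 1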
@@ -43,6 +43,8 @@ struct tf_options
/// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, const tf_options& options = tf_options{});
std::vector<std::string> get_tf_operators();
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
*
* See normalize_attribute.hpp for an explanation of the options.
*/
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
const std::vector<std::size_t>& lens)
const std::vector<std::size_t>& lens,
Message m)
{
std::vector<int64_t> result(vec);
int64_t n_rank = lens.size();
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
MIGRAPHX_THROW(m() + "value out of range!");
}
}
else
{
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
MIGRAPHX_THROW(m() + "value out of range!");
}
}
}
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
MIGRAPHX_THROW(m() + "attribute out of range!");
}
}
else
{
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
MIGRAPHX_THROW(m() + "attribute out of range!");
}
}
}
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key();
if(val.contains(key))
{
auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array())
{
std::vector<int64_t> axes;
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>();
}
auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens);
auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
val[key] = result;
op.from_value(val);
val = op.to_value();
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else
{
auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
val[key] = result.front();
op.from_value(val);
val = op.to_value();
......
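The practical effect of threading the message lambda through tune_attribute is that out-of-range errors now name the offending operator and attribute, e.g. something like "slice: ends: value out of range!" instead of the generic "TUNE_VECTOR: value out of range!".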
@@ -39,7 +39,19 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REMOVE_LAST_OUTPUT);
static shape shape_from_dyn_dims(shape::type_t shape_type,
const std::vector<shape::dynamic_dimension>& dyn_dims)
{
if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
return d.max;
});
return {shape_type, dims};
}
return {shape_type, dyn_dims};
}
namespace onnx {
@@ -302,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
else if(map_dyn_input_dims.count(name) > 0)
{
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)};
s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
}
else
{
@@ -508,16 +520,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
{
return {shape_type};
}
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims};
}
return {shape_type, dynamic_dims};
return shape_from_dyn_dims(shape_type, dynamic_dims);
}
shape::type_t get_type(int dtype)
......
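The shape_from_dyn_dims helper factored out above collapses a fully fixed set of dynamic dimensions into a static shape; in Python terms (illustrative sketch, with (min, max) pairs standing in for dynamic_dimension):

def shape_from_dyn_dims(dyn_dims):
    if all(lo == hi for lo, hi in dyn_dims):
        return [hi for _, hi in dyn_dims]  # every range is fixed: static shape
    return dyn_dims  # otherwise the shape stays dynamic

assert shape_from_dyn_dims([(2, 2), (3, 3)]) == [2, 3]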
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
op_parser_map().end(),
std::back_inserter(result),
[&](auto&& p) { return p.first; });
std::sort(result.begin(), result.end());
return result;
}
......
@@ -103,6 +103,7 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override
{
trace("Pass: ", p.name());
assert(mod);
assert(mod->validate() == mod->end());
if(enabled(MIGRAPHX_TIME_PASSES{}))
......
@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "contiguous")
@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
{
if(not literals[i].empty())
{
if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
{
std::cout << "Constant replace: " << std::endl;
std::vector<instruction_ref> inss;
fix([&](auto self, auto ins) {
if(contains(inss, ins))
return;
for(auto input : ins->inputs())
self(input);
inss.push_back(ins);
})(const_instrs_vec[i]);
m.debug_print(inss);
}
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
......
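With the MIGRAPHX_TRACE_PROPAGATE_CONSTANT environment variable enabled, each replacement prints the replaced instruction together with the chain of inputs that produced it; the fix combinator above walks that chain depth-first, so the instructions are printed in dependency order.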
@@ -62,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx {
migraphx::value to_value(py::kwargs kwargs);
@@ -94,6 +95,10 @@ void visit_py(T x, F f)
{
f(x.template cast<std::string>());
}
else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
{
f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
}
else
{
MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
@@ -164,7 +169,10 @@ template <class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
auto strides = s.strides();
assert(s.type() != migraphx::shape::tuple_type);
if(s.dynamic())
MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
auto strides = s.strides();
std::transform(
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
py::buffer_info b;
@@ -176,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<bool>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
@@ -185,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
@@ -235,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
if(v.contains("dyn_dims"))
{
auto dyn_dims =
migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
v.at("dyn_dims"));
return migraphx::shape(t, dyn_dims);
}
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
@@ -248,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("ndim", &migraphx::shape::ndim)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("dyn_dims", &migraphx::shape::dyn_dims)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar)
.def("dynamic", &migraphx::shape::dynamic)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
.def(py::init<>())
.def(py::init<std::size_t, std::size_t>())
.def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
.def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
.def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
.def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
.def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) {
@@ -282,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
@@ -433,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
@@ -447,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
@@ -457,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
......
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
std::vector<std::size_t> shape::multi(std::size_t idx) const
{
assert(this->standard());
assert(idx < elements());
std::vector<std::size_t> indices(lens().size());
multi_copy(i, indices.data(), indices.data() + lens().size());
multi_copy(idx, indices.data(), indices.data() + lens().size());
return indices;
}
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
assert(this->standard());
size_t tidx = idx;
(void)end;
assert(idx < elements());
assert(lens().size() <= (end - start));
std::transform(strides().begin(),
strides().end(),
lens().begin(),
start,
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
for(size_t ii = lens().size() - 1; ii > 0; ii--)
{
*(start + ii) = tidx % lens()[ii];
tidx = tidx / lens()[ii];
}
*start = tidx;
}
bool shape::packed() const
@@ -709,14 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto v_optimals = x.at("optimals");
std::set<size_t> set_x_optimals =
from_value<std::set<std::size_t>>(x.at("optimals"));
return shape::dynamic_dimension{x_min, x_max, set_x_optimals};
});
std::transform(
v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
return from_value<shape::dynamic_dimension>(x);
});
s = shape{shape::parse_type(t), dyn_dims};
}
......
@@ -204,6 +204,131 @@ struct find_mul_slice_conv
}
};
struct find_mul_dot
{
auto matcher() const
{
auto is_dot_const_inputs =
match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
return match::name("mul")(match::either_arg(0, 1)(
is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
auto add_mul_const = [&](instruction_ref x_ins) {
if(not x_ins->can_eval())
return m.end();
auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
};
if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins);
}
else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins);
}
else if(c_ins->get_shape().scalar())
{
if(a_ins->can_eval())
a_ins = add_mul_const(a_ins);
else
b_ins = add_mul_const(b_ins);
}
else
{
return;
}
if(contains({a_ins, b_ins}, m.end()))
return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
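find_mul_dot relies on multiplication by a broadcast constant commuting with the dot: a constant with a nonzero stride only on the last axis folds into b, while one with a nonzero stride only on the second-to-last axis folds into a (find_dot_mul below applies the mirrored rewrite). A quick numpy check of the identity (illustrative):

import numpy as np
a, b = np.random.rand(2, 3), np.random.rand(3, 4)
c_row = np.random.rand(4)     # varies along the last axis only
assert np.allclose((a @ b) * c_row, a @ (b * c_row))
c_col = np.random.rand(2, 1)  # varies along the row axis only
assert np.allclose((a @ b) * c_col, (a * c_col) @ b)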
struct find_dot_mul
{
auto matcher() const
{
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(
match::used_once(),
match::either_arg(0, 1)(const_broadcast.bind("d"),
match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
auto d_ins = r.instructions["d"];
auto c_ins = r.instructions["c"];
auto z_ins = r.instructions["z"];
const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
if(not d_ins->get_shape().scalar())
{
if(d_strides.back() == 1 and not b_ins->can_eval())
return;
if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return;
}
auto broadcast_v = d_ins->get_operator().to_value();
auto c_lens = c_ins->get_shape().lens();
std::vector<int64_t> permutation(c_lens.size());
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation.back(), permutation[permutation.size() - 2]);
c_lens = reorder_dims(c_lens, permutation);
broadcast_v["out_lens"] = c_lens;
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto db_transpose_ins =
m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
if(c_ins == b_ins)
{
a_ins = z_ins;
b_ins = cd_ins;
}
else
{
a_ins = cd_ins;
b_ins = z_ins;
}
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
@@ -361,30 +486,118 @@ struct find_inner_broadcast
{
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
static auto non_scalar_op(const std::string& name)
{
return [=](instruction_ref ins) {
if(ins->get_shape().scalar())
return false;
return ins->name() == name;
};
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcasts = ins->inputs();
if(broadcasts.empty())
return;
bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
any_of(broadcasts, non_scalar_op("multibroadcast"));
// If the broadcast is not a single dimension, then don't perform inner_broadcast
if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
if(i->get_shape().scalar())
return false;
if(i->name() == "multibroadcast")
return false;
auto input = i->inputs().at(0);
const auto& lens = input->get_shape().lens();
return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
return d == 1;
}) < (lens.size() - 1);
}))
return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape() and
i->get_shape().elements() != 1;
}))
return;
auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return not i->get_shape().scalar();
});
if(b_it == broadcasts.end())
b_it = broadcasts.begin();
std::transform(broadcasts.begin(),
               broadcasts.end(),
               std::back_inserter(inputs),
               [&](instruction_ref i) {
                   auto input = i->inputs().front();
                   if(mixed_broadcasts and not i->get_shape().scalar() and
                      i->get_shape().lens().size() > 1)
                       return m.insert_instruction(i, make_op("squeeze"), input);
                   return input;
               });
std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
if(i->get_shape().scalar())
return 2;
else if(i->name() == "broadcast")
return 0;
if(i->name() == "multibroadcast")
return 1;
return 3;
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
struct find_dot_broadcast
{
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if(a->get_operator().name() != b->get_operator().name())
return;
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
auto p =
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
b_ins->get_shape().lens().end());
if(b_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(b_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
assert(false);
return m.end();
};
auto a1 = insert_broadcast(a);
auto b1 = insert_broadcast(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
}
};
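The rewrite in find_dot_broadcast is sound because a matmul whose leading batch axes are pure broadcasts computes the same product in every batch slice, so the dot can be computed once and the result rebroadcast (numpy check, illustrative):

import numpy as np
a, b = np.random.rand(2, 3), np.random.rand(3, 4)
ab = np.broadcast_to(a, (5, 2, 3))  # broadcast leading batch axis
bb = np.broadcast_to(b, (5, 3, 4))
assert np.allclose(ab @ bb, np.broadcast_to(a @ b, (5, 2, 4)))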
@@ -393,7 +606,8 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
}
template <class Iterator>
@@ -412,7 +626,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
}
void apply(module& m, const match::matcher_result& r) const
@@ -440,6 +655,16 @@ struct find_concat_op
op = b;
iaxis = 0;
}
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
@@ -1260,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
{
match::find_matches(m,
find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
find_conv_dot_horiz_fusion{},
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
......