Commit 421ecad6 authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents 3d0426e9 7cf05301
@@ -45,57 +45,48 @@ def shape_type_wrap(p):
     p.read = 'migraphx::to_shape_type(${name})'
 
-@api.cwrap('migraphx::compile_options')
-def compile_options_type_wrap(p):
-    if p.returns:
-        p.add_param('migraphx_compile_options *')
-        p.bad_param('${name} == nullptr', 'Null pointer')
-        p.write = ['*${name} = migraphx::to_compile_options(${result})']
-    else:
-        p.add_param('migraphx_compile_options *')
-        p.read = '${name} == nullptr ? migraphx::compile_options{} : migraphx::to_compile_options(*${name})'
-
-@api.cwrap('migraphx::file_options')
-def file_options_type_wrap(p):
-    if p.returns:
-        p.add_param('migraphx_file_options *')
-        p.bad_param('${name} == nullptr', 'Null pointer')
-        p.write = ['*${name} = migraphx::to_file_options(${result})']
-    else:
-        p.add_param('migraphx_file_options *')
-        p.read = '${name} == nullptr ? migraphx::file_options{} : migraphx::to_file_options(*${name})'
-
-@api.cwrap('migraphx::onnx_options')
-def onnx_options_type_wrap(p):
-    if p.returns:
-        p.add_param('migraphx_onnx_options *')
-        p.bad_param('${name} == nullptr', 'Null pointer')
-        p.write = ['*${name} = migraphx::to_onnx_options(${result})']
-    else:
-        p.add_param('migraphx_onnx_options *')
-        p.read = '${name} == nullptr ? migraphx::onnx_options{} : migraphx::to_onnx_options(*${name})'
-
-@api.cwrap('migraphx::tf_options')
-def tf_options_type_wrap(p):
-    if p.returns:
-        p.add_param('migraphx_tf_options *')
-        p.bad_param('${name} == nullptr', 'Null pointer')
-        p.write = ['*${name} = migraphx::to_tf_options(${result})']
-    else:
-        p.add_param('migraphx_tf_options *')
-        p.read = '${name} == nullptr ? migraphx::tf_options{} : migraphx::to_tf_options(*${name})'
-
 def auto_handle(*args, **kwargs):
     def with_handle(f):
         return api.handle('migraphx_' + f.__name__, 'migraphx::' + f.__name__,
                           *args, **kwargs)(f)
 
     return with_handle
 
+@api.handle('migraphx_optimals', 'std::set<size_t>')
+def optimals(h):
+    h.constructor('create',
+                  api.params(ptr='const size_t*', size='size_t'),
+                  fname='migraphx::make_set<size_t>')
+
+@api.handle('migraphx_dynamic_dimension', 'migraphx::shape::dynamic_dimension')
+def dynamic_dimension(h):
+    h.constructor('create_min_max', api.params(min='size_t', max='size_t'))
+    h.constructor(
+        'create_min_max_optimals',
+        api.params(min='size_t', max='size_t', optimals='std::set<size_t>'))
+    h.method('is_fixed', returns='bool', const=True)
+    h.method('equal',
+             api.params(x='const migraphx::shape::dynamic_dimension&'),
+             invoke='migraphx::equal($@)',
+             returns='bool',
+             const=True)
+
+@api.handle('migraphx_dynamic_dimensions',
+            'std::vector<migraphx::shape::dynamic_dimension>')
+def dynamic_dimensions(h):
+    h.constructor(
+        'create',
+        api.params(ptr='const_migraphx_dynamic_dimension_t*', size='size_t'),
+        fname='migraphx::to_obj_vector<const_migraphx_dynamic_dimension_t>')
+    h.method('size', returns='size_t')
+    h.method('get',
+             api.params(idx='size_t'),
+             fname='at',
+             cpp_name='operator[]',
+             returns='const migraphx::shape::dynamic_dimension&')
+
 @auto_handle()
@@ -110,20 +101,29 @@ def shape(h):
                           lengths='std::vector<size_t>',
                           strides='std::vector<size_t>'))
     h.constructor('create_scalar', api.params(type='migraphx::shape::type_t'))
+    h.constructor(
+        'create_dynamic',
+        api.params(type='migraphx::shape::type_t',
+                   dims='std::vector<migraphx::shape::dynamic_dimension>'))
     h.method('lengths',
              fname='lens',
              returns='const std::vector<size_t>&',
              const=True)
     h.method('strides', returns='const std::vector<size_t>&', const=True)
+    h.method('dyn_dims',
+             returns='std::vector<migraphx::shape::dynamic_dimension>',
+             const=True)
     h.method('type', returns='migraphx::shape::type_t', const=True)
     h.method('elements', returns='size_t', const=True)
     h.method('bytes', returns='size_t', const=True)
+    h.method('ndim', returns='size_t', const=True)
     h.method('equal',
              api.params(x='const migraphx::shape&'),
             invoke='migraphx::equal($@)',
             returns='bool',
             const=True)
     h.method('standard', returns='bool', const=True)
+    h.method('dynamic', returns='bool', const=True)
     h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@@ -131,6 +131,7 @@ def shape(h):
 def argument(h):
     h.constructor('create',
                   api.params(shape='const migraphx::shape&', buffer='void*'))
+    h.constructor('create_empty', api.params(shape='const migraphx::shape&'))
     h.method('shape',
              fname='get_shape',
              cpp_name='get_shape',
@@ -326,11 +327,22 @@ def onnx_options(h):
         api.params(name='const char*', dims='std::vector<size_t>'),
         invoke='migraphx::set_input_parameter_shape($@)',
     )
+    h.method(
+        'set_dyn_input_parameter_shape',
+        api.params(name='const char*',
+                   dims='std::vector<migraphx::shape::dynamic_dimension>'),
+        invoke='migraphx::set_dyn_input_parameter_shape($@)',
+    )
     h.method(
         'set_default_dim_value',
         api.params(value='size_t'),
         invoke='migraphx::set_default_dim_value($@)',
     )
+    h.method(
+        'set_default_dyn_dim_value',
+        api.params(dd='const migraphx::shape::dynamic_dimension&'),
+        invoke='migraphx::set_default_dyn_dim_value($@)',
+    )
     h.method(
         'set_default_loop_iterations',
         api.params(value='int64_t'),
...
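For orientation, each api.handle registration in this file is expanded by the C API generator into an opaque handle type plus free functions whose names are derived from the handle and member names; the migraphx_dynamic_dimension handle above should therefore yield functions along the lines of migraphx_dynamic_dimension_create_min_max and migraphx_dynamic_dimension_is_fixed (names inferred from the registration strings; the generated header is not part of this diff).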
@@ -106,6 +106,11 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
     return *this;
 }
 
+cpp_generator::function& cpp_generator::function::unused_param(const std::string& pname)
+{
+    body.insert(0, "(void)" + pname + ";\n");
+    return *this;
+}
+
 cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
 {
     params.push_back({pname, "T" + pname});
@@ -174,6 +179,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
     else if(with_char(::isdigit)(key[0]))
     {
         auto i = std::stoul(key);
+        if(i >= args.size())
+            MIGRAPHX_THROW("Invalid argument index: " + key);
         return args.at(i);
     }
     else if(v.contains(key))
@@ -238,6 +245,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
     std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
     impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
     char delim = '(';
+    if(f.params.empty())
+        impl->fs << delim;
     for(auto&& p : f.params)
     {
         impl->fs << delim << p.type << " " << p.name;
...
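The create_function change fixes a corner case in the comma-delimiter idiom: the opening parenthesis is only written inside the loop over parameters, so a parameterless function previously emitted no "(" at all. A minimal sketch of the idiom and the new guard (illustrative only, not the generator's actual code):

```python
def emit_params(params):
    out = ""
    delim = "("
    if not params:        # the new guard: still emit "(" for an empty parameter list
        out += delim
    for p in params:
        out += delim + p  # first iteration writes "(", later iterations write ","
        delim = ","
    return out + ")"

assert emit_params([]) == "()"
assert emit_params(["int x", "int y"]) == "(int x,int y)"
```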
@@ -436,11 +436,6 @@ struct compiler
        {"--exhaustive-tune"},
        ap.help("Exhastively search for best tuning parameters for kernels"),
        ap.set_value(true));
-    ap(co.split_single_dyn_dim,
-       {"--split-single-dyn-dim"},
-       ap.help("If there is a single non-fixed dynamic dimension in the model, then split to "
-               "static submodules"),
-       ap.set_value(true));
     ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
     ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
 }
@@ -662,6 +657,26 @@ struct onnx : command<onnx>
     }
 };
 
+struct tf : command<tf>
+{
+    bool show_ops = false;
+    void parse(argument_parser& ap)
+    {
+        ap(show_ops,
+           {"--list", "-l"},
+           ap.help("List all tf operators supported by MIGraphX"),
+           ap.set_value(true));
+    }
+    void run() const
+    {
+        if(show_ops)
+        {
+            for(const auto& name : get_tf_operators())
+                std::cout << name << std::endl;
+        }
+    }
+};
+
 struct main_command
 {
     static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
...
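With the new tf subcommand, an invocation along the lines of migraphx-driver tf --list should print every TF operator MIGraphX can parse, presumably mirroring the listing behavior of the existing onnx command (the exact driver invocation is inferred from the command registration, not shown in this diff).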
@@ -31,6 +31,8 @@
 #include <migraphx/ranges.hpp>
 #include <iterator>
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
+
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -67,13 +69,13 @@ static void create_pointwise_modules(module_pass_manager& mpm)
             continue;
         if(ins->get_operator().name() == "layout")
             continue;
-        assert(ins->get_operator().attributes().contains("point_op"));
         auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
         pm->set_bypass();
         std::unordered_map<instruction_ref, instruction_ref> param_map;
         std::vector<instruction_ref> pointwise_inputs;
         std::size_t i = 0;
 
         for(auto input : ins->inputs())
         {
             if(contains(param_map, input))
@@ -92,6 +94,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
             }
         }
 
+        // Don't create pointwise module if no inputs are detected
+        if(pointwise_inputs.empty())
+            continue;
+
         std::vector<instruction_ref> inputs;
         std::transform(ins->inputs().begin(),
                        ins->inputs().end(),
@@ -188,6 +194,10 @@ void fuse_pointwise::apply(module_pass_manager& mpm) const
 {
     create_pointwise_modules(mpm);
     mpm.run_pass(dead_code_elimination{});
+    if(enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}))
+    {
+        return;
+    }
     for(int i = 0; i < 8; i++)
     {
         if(not find_pointwise_modules(mpm.get_module()))
...
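Note where the new escape hatch sits: when MIGRAPHX_DISABLE_POINTWISE_FUSION is set, pointwise modules are still created and dead code is still eliminated, but the loop that merges neighboring pointwise modules into larger fusions is skipped.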
@@ -40,9 +40,6 @@ struct compile_options
     bool fast_math = true;
     bool exhaustive_tune = false;
 
-    /// Use the split_single_dyn_dim pass
-    bool split_single_dyn_dim = false;
-
     tracer trace{};
 };
...
@@ -78,6 +78,7 @@ struct cpp_generator
         function& set_types(const module& m, const std::function<std::string(shape)>& parse);
         function& set_generic_types(const module& m);
         function& add_generic_param(const std::string& pname);
+        function& unused_param(const std::string& pname);
     };
 
     cpp_generator();
...
@@ -37,6 +37,15 @@ namespace op {
 struct dequantizelinear
 {
+    value attributes() const
+    {
+        // Note: point_op attribute is not used in this op. Instead, in
+        // gpu compilation pipeline, rewrite_quantization will be invoked
+        // from generate_pointwise() to rewrite this op.
+        return {{"pointwise", true}};
+    }
+
     std::string name() const { return "dequantizelinear"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
...
@@ -45,14 +45,15 @@ struct pointwise
         {
             MIGRAPHX_THROW("should have one submodule.");
         }
         auto* pm = mods.front();
-        if(pm->get_output_shapes().size() != 1)
-            MIGRAPHX_THROW("pointwise should have only one output.");
-        if(inputs.empty())
-            MIGRAPHX_THROW("pointwise should have at least one input");
         auto pnames = pm->get_parameter_names();
         std::sort(pnames.begin(), pnames.end());
         check_shapes{inputs, *this}.has(pnames.size()).same_dims();
+        if(pm->get_output_shapes().size() != 1)
+            MIGRAPHX_THROW("submodule should have only one output.");
         auto type = pm->get_output_shapes().front().type();
 
         // Scalar output if all inputs are scalar
...
@@ -38,6 +38,15 @@ namespace op {
 struct quantizelinear
 {
     std::string name() const { return "quantizelinear"; }
+    value attributes() const
+    {
+        // Note: point_op attribute is not used in this op. Instead, in
+        // gpu compilation pipeline, rewrite_quantization will be invoked
+        // from generate_pointwise() to rewrite this op.
+        return {{"pointwise", true}};
+    }
+
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.same_dims().has(2, 3);
...
@@ -125,7 +125,7 @@ struct select_module
             auto ps = param_shapes.at(name);
             if(a.get_shape() != ps)
             {
-                assert(ps.bytes() == a.get_shape().bytes());
+                assert(ps.bytes() <= a.get_shape().bytes());
                 return std::make_pair(name, a.reshape(ps));
             }
             else
...
@@ -222,11 +222,15 @@ struct shape
     /// Map element index to space index
     std::size_t index(std::size_t i) const;
 
-    std::vector<std::size_t> multi(std::size_t i) const;
-    void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
-    /// Returns true if the shape is packed (number of elements and buffer size the same) with no
-    /// padding
+    /// Map element index to multi-dimensional index
+    std::vector<std::size_t> multi(std::size_t idx) const;
+
+    /// Map element index to multi-dimensional index and put it into the location provided by
+    /// the pointers
+    void multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const;
+
+    /// Returns true if the shape is packed (number of elements and buffer size the same) with
+    /// no padding
     bool packed() const;
 
     /// Returns true is the shape has been transposed. That is the strides are not in descending
...
@@ -43,6 +43,8 @@ struct tf_options
 /// Create a program from a tf pb file (default is nhwc format)
 program parse_tf(const std::string& name, const tf_options& options = tf_options{});
 
+std::vector<std::string> get_tf_operators();
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
  *
  * See normalize_attribute.hpp for explaining the options.
  */
+template <class Message>
 auto tune_attribute(const std::vector<int64_t>& vec,
                     const std::vector<int64_t>& axes,
                     const value& val,
-                    const std::vector<std::size_t>& lens)
+                    const std::vector<std::size_t>& lens,
+                    Message m)
 {
     std::vector<int64_t> result(vec);
     int64_t n_rank = lens.size();
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
     else
     {
         if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
         {
-            MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
+            MIGRAPHX_THROW(m() + "value out of range!");
         }
     }
 }
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
             if(not std::equal(
                    min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
         else
         {
             if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
             {
-                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
+                MIGRAPHX_THROW(m() + "attribute out of range!");
             }
         }
     }
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
         const auto& key = rv.get_key();
         if(val.contains(key))
         {
-            auto vv = val.at(key).without_key();
+            auto message = [&] { return op.name() + ": " + key + ": "; };
+            auto vv = val.at(key).without_key();
             if(vv.is_array())
             {
                 std::vector<int64_t> axes;
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
                 axes = val.at("axes").without_key().to_vector<int64_t>();
             }
             auto vec = vv.to_vector<int64_t>();
-            auto result = tune_attribute(vec, axes, rv.without_key(), lens);
+            auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
             val[key] = result;
             op.from_value(val);
             val = op.to_value();
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
             else
             {
                 auto num = vv.to<int64_t>();
-                auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
+                auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
                 val[key] = result.front();
                 op.from_value(val);
                 val = op.to_value();
...
@@ -39,7 +39,19 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REMOVE_LAST_OUTPUT);
+static shape shape_from_dyn_dims(shape::type_t shape_type,
+                                 const std::vector<shape::dynamic_dimension>& dyn_dims)
+{
+    if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
+    {
+        std::vector<std::size_t> dims;
+        std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
+            return d.max;
+        });
+        return {shape_type, dims};
+    }
+    return {shape_type, dyn_dims};
+}
 
 namespace onnx {
@@ -302,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
         else if(map_dyn_input_dims.count(name) > 0)
         {
             shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
-            s = {shape_type, map_dyn_input_dims.at(name)};
+            s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
         }
         else
         {
@@ -508,16 +520,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
     {
         return {shape_type};
     }
-    if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
-    {
-        std::vector<std::size_t> dims;
-        std::transform(dynamic_dims.begin(),
-                       dynamic_dims.end(),
-                       std::back_inserter(dims),
-                       [](auto d) { return d.max; });
-        return {shape_type, dims};
-    }
-    return {shape_type, dynamic_dims};
+    return shape_from_dyn_dims(shape_type, dynamic_dims);
 }
 
 shape::type_t get_type(int dtype)
...
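The extracted shape_from_dyn_dims helper gives inputs supplied through map_dyn_input_dims the same treatment parse_type already had: a dynamic specification whose dimensions are all fixed collapses to an ordinary static shape. A rough Python sketch of the rule, modeling each dimension as a (min, max) pair:

```python
def shape_from_dyn_dims(dyn_dims):
    """Collapse an all-fixed dynamic spec into static dims; otherwise stay dynamic."""
    if all(lo == hi for lo, hi in dyn_dims):
        return [hi for _, hi in dyn_dims]  # static lens
    return dyn_dims                        # keeps its dynamic form

assert shape_from_dyn_dims([(3, 3), (224, 224)]) == [3, 224]
```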
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
                    op_parser_map().end(),
                    std::back_inserter(result),
                    [&](auto&& p) { return p.first; });
+    std::sort(result.begin(), result.end());
     return result;
 }
...
@@ -103,6 +103,7 @@ struct module_pm : module_pass_manager
     virtual void run_pass(const pass& p) override
     {
+        trace("Pass: ", p.name());
         assert(mod);
         assert(mod->validate() == mod->end());
         if(enabled(MIGRAPHX_TIME_PASSES{}))
...
@@ -27,11 +27,14 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/par_for.hpp>
+#include <migraphx/env.hpp>
 #include <unordered_set>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
+
 bool skip_propogate(instruction_ref ins)
 {
     if(ins->name() == "contiguous")
@@ -85,6 +88,19 @@ void propagate_constant::apply(module& m) const
     {
         if(not literals[i].empty())
         {
+            if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
+            {
+                std::cout << "Constant replace: " << std::endl;
+                std::vector<instruction_ref> inss;
+                fix([&](auto self, auto ins) {
+                    if(contains(inss, ins))
+                        return;
+                    for(auto input : ins->inputs())
+                        self(input);
+                    inss.push_back(ins);
+                })(const_instrs_vec[i]);
+                m.debug_print(inss);
+            }
             assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
             auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
             m.replace_instruction(const_instrs_vec[i], l);
...
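When MIGRAPHX_TRACE_PROPAGATE_CONSTANT is set, the pass now prints each constant subgraph before replacing it: the fix combinator walks the replaced instruction's transitive inputs depth-first, collecting them in topological order for debug_print.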
@@ -62,6 +62,7 @@ namespace py = pybind11;
     PYBIND11_MODULE(__VA_ARGS__) \
     MIGRAPHX_POP_WARNING
 
+#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
+
 namespace migraphx {
 
 migraphx::value to_value(py::kwargs kwargs);
@@ -94,6 +95,10 @@ void visit_py(T x, F f)
     {
         f(x.template cast<std::string>());
     }
+    else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
+    {
+        f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
+    }
     else
     {
         MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
@@ -164,7 +169,10 @@ template <class T>
 py::buffer_info to_buffer_info(T& x)
 {
     migraphx::shape s = x.get_shape();
+    assert(s.type() != migraphx::shape::tuple_type);
+    if(s.dynamic())
+        MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
     auto strides = s.strides();
     std::transform(
         strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
     py::buffer_info b;
@@ -176,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
         b = py::buffer_info(x.data(),
                             as.size(),
                             py::format_descriptor<bool>::format(),
-                            s.lens().size(),
+                            s.ndim(),
                             s.lens(),
                             strides);
     }
@@ -185,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
         b = py::buffer_info(x.data(),
                             as.size(),
                             py::format_descriptor<decltype(as())>::format(),
-                            s.lens().size(),
+                            s.ndim(),
                             s.lens(),
                             strides);
     }
@@ -235,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
 MIGRAPHX_PYBIND11_MODULE(migraphx, m)
 {
-    py::class_<migraphx::shape>(m, "shape")
+    py::class_<migraphx::shape> shape_cls(m, "shape");
+    shape_cls
         .def(py::init([](py::kwargs kwargs) {
             auto v = migraphx::to_value(kwargs);
             auto t = migraphx::shape::parse_type(v.get("type", "float"));
+            if(v.contains("dyn_dims"))
+            {
+                auto dyn_dims =
+                    migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
+                        v.at("dyn_dims"));
+                return migraphx::shape(t, dyn_dims);
+            }
             auto lens = v.get<std::size_t>("lens", {1});
             if(v.contains("strides"))
                 return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
@@ -248,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("type", &migraphx::shape::type)
         .def("lens", &migraphx::shape::lens)
         .def("strides", &migraphx::shape::strides)
+        .def("ndim", &migraphx::shape::ndim)
         .def("elements", &migraphx::shape::elements)
         .def("bytes", &migraphx::shape::bytes)
         .def("type_string", &migraphx::shape::type_string)
         .def("type_size", &migraphx::shape::type_size)
+        .def("dyn_dims", &migraphx::shape::dyn_dims)
         .def("packed", &migraphx::shape::packed)
         .def("transposed", &migraphx::shape::transposed)
         .def("broadcasted", &migraphx::shape::broadcasted)
         .def("standard", &migraphx::shape::standard)
         .def("scalar", &migraphx::shape::scalar)
+        .def("dynamic", &migraphx::shape::dynamic)
         .def("__eq__", std::equal_to<migraphx::shape>{})
         .def("__ne__", std::not_equal_to<migraphx::shape>{})
         .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
+
+    py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
+        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
+
+    py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
+        .def(py::init<>())
+        .def(py::init<std::size_t, std::size_t>())
+        .def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
+        .def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
+        .def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
+        .def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
+        .def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);
+
     py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
         .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
         .def(py::init([](py::buffer b) {
@@ -282,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
     py::class_<migraphx::target>(m, "target");
 
-    py::class_<migraphx::instruction_ref>(m, "instruction_ref");
+    py::class_<migraphx::instruction_ref>(m, "instruction_ref")
+        .def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
+        .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
 
     py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
         .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
@@ -433,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         "parse_onnx",
         [](const std::string& filename,
            unsigned int default_dim_value,
+           migraphx::shape::dynamic_dimension default_dyn_dim_value,
            std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
+           std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
+               map_dyn_input_dims,
            bool skip_unknown_operators,
            bool print_program_on_error,
            int64_t max_loop_iterations) {
             migraphx::onnx_options options;
             options.default_dim_value = default_dim_value;
+            options.default_dyn_dim_value = default_dyn_dim_value;
             options.map_input_dims = map_input_dims;
+            options.map_dyn_input_dims = map_dyn_input_dims;
             options.skip_unknown_operators = skip_unknown_operators;
             options.print_program_on_error = print_program_on_error;
             options.max_loop_iterations = max_loop_iterations;
@@ -447,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         },
         "Parse onnx file",
         py::arg("filename"),
-        py::arg("default_dim_value") = 1,
+        py::arg("default_dim_value") = 0,
+        py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
         py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
+        py::arg("map_dyn_input_dims") =
+            std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
         py::arg("skip_unknown_operators") = false,
         py::arg("print_program_on_error") = false,
         py::arg("max_loop_iterations") = 10);
@@ -457,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         "parse_onnx_buffer",
         [](const std::string& onnx_buffer,
            unsigned int default_dim_value,
+           migraphx::shape::dynamic_dimension default_dyn_dim_value,
            std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
+           std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
+               map_dyn_input_dims,
            bool skip_unknown_operators,
            bool print_program_on_error) {
             migraphx::onnx_options options;
             options.default_dim_value = default_dim_value;
+            options.default_dyn_dim_value = default_dyn_dim_value;
             options.map_input_dims = map_input_dims;
+            options.map_dyn_input_dims = map_dyn_input_dims;
             options.skip_unknown_operators = skip_unknown_operators;
             options.print_program_on_error = print_program_on_error;
             return migraphx::parse_onnx_buffer(onnx_buffer, options);
         },
         "Parse onnx file",
         py::arg("filename"),
-        py::arg("default_dim_value") = 1,
+        py::arg("default_dim_value") = 0,
+        py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
         py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
+        py::arg("map_dyn_input_dims") =
+            std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
         py::arg("skip_unknown_operators") = false,
         py::arg("print_program_on_error") = false);
...
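Taken together, the bindings above expose dynamic shapes to Python. A minimal sketch of how they might be used (model path and input name are hypothetical, and the program's existing get_parameter_shapes binding is assumed):

```python
import migraphx

# Batch dimension ranges from 1 to 4; the remaining dimensions are fixed.
dyn_dims = [migraphx.shape.dynamic_dimension(1, 4),
            migraphx.shape.dynamic_dimension(3, 3),
            migraphx.shape.dynamic_dimension(224, 224),
            migraphx.shape.dynamic_dimension(224, 224)]
p = migraphx.parse_onnx("model.onnx", map_dyn_input_dims={"x": dyn_dims})

s = p.get_parameter_shapes()["x"]
print(s.dynamic())   # True
print(s.ndim())      # 4
print(s.dyn_dims())  # the dynamic_dimension list above
```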
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
     }
 }
 
-std::vector<std::size_t> shape::multi(std::size_t i) const
+std::vector<std::size_t> shape::multi(std::size_t idx) const
 {
-    assert(this->standard());
+    assert(idx < elements());
     std::vector<std::size_t> indices(lens().size());
-    multi_copy(i, indices.data(), indices.data() + lens().size());
+    multi_copy(idx, indices.data(), indices.data() + lens().size());
     return indices;
 }
 
-void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
+void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
 {
-    assert(this->standard());
+    size_t tidx = idx;
     (void)end;
+    assert(idx < elements());
     assert(lens().size() <= (end - start));
-    std::transform(strides().begin(),
-                   strides().end(),
-                   lens().begin(),
-                   start,
-                   [&](std::size_t stride, std::size_t len) {
-                       assert(len > 0 and stride > 0);
-                       return (i / stride) % len;
-                   });
+    for(size_t ii = lens().size() - 1; ii > 0; ii--)
+    {
+        *(start + ii) = tidx % lens()[ii];
+        tidx          = tidx / lens()[ii];
+    }
+    *start = tidx;
 }
 
 bool shape::packed() const
@@ -709,14 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
     {
         auto v_dd = v.at("dynamic_dimensions");
         std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
-        std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
-            auto x_min = x.at("min").template to<size_t>();
-            auto x_max = x.at("max").template to<size_t>();
-            auto v_optimals = x.at("optimals");
-            std::set<size_t> set_x_optimals =
-                from_value<std::set<std::size_t>>(x.at("optimals"));
-            return shape::dynamic_dimension{x_min, x_max, set_x_optimals};
-        });
+        std::transform(
+            v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
+                return from_value<shape::dynamic_dimension>(x);
+            });
         s = shape{shape::parse_type(t), dyn_dims};
     }
...
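The rewritten multi_copy drops the stride-division formula, which was only valid for standard layouts (hence the old assert(this->standard())), in favor of plain row-major decomposition over lens(). The same arithmetic as a short Python sketch:

```python
def multi(idx, lens):
    """Row-major (C-order) decomposition of a flat element index."""
    out = [0] * len(lens)
    for i in range(len(lens) - 1, 0, -1):
        out[i] = idx % lens[i]
        idx //= lens[i]
    out[0] = idx
    return out

assert multi(5, [2, 3]) == [1, 2]  # last element of a 2x3 shape
```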
@@ -204,6 +204,131 @@ struct find_mul_slice_conv
     }
 };
 
+struct find_mul_dot
+{
+    auto matcher() const
+    {
+        auto is_dot_const_inputs =
+            match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
+        return match::name("mul")(match::either_arg(0, 1)(
+            is_dot_const_inputs.bind("dot"),
+            match::name("broadcast", "multibroadcast").bind("c")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins     = r.result;
+        auto dot_ins = r.instructions["dot"];
+        auto a_ins   = dot_ins->inputs()[0];
+        auto b_ins   = dot_ins->inputs()[1];
+        auto c_ins   = r.instructions["c"];
+        const auto& c_strides = c_ins->get_shape().strides();
+        // There should only be one stride that is not zero
+        if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
+            return;
+
+        auto add_mul_const = [&](instruction_ref x_ins) {
+            if(not x_ins->can_eval())
+                return m.end();
+            auto broadcast_v        = c_ins->get_operator().to_value();
+            broadcast_v["out_lens"] = x_ins->get_shape().lens();
+            auto cb_ins =
+                m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
+            return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
+        };
+
+        if(c_strides.back() == 1)
+        {
+            b_ins = add_mul_const(b_ins);
+        }
+        else if(c_strides[c_strides.size() - 2] == 1)
+        {
+            a_ins = add_mul_const(a_ins);
+        }
+        else if(c_ins->get_shape().scalar())
+        {
+            if(a_ins->can_eval())
+                a_ins = add_mul_const(a_ins);
+            else
+                b_ins = add_mul_const(b_ins);
+        }
+        else
+        {
+            return;
+        }
+
+        if(contains({a_ins, b_ins}, m.end()))
+            return;
+
+        m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
+    }
+};
+
+struct find_dot_mul
+{
+    auto matcher() const
+    {
+        auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
+        auto mul             = match::name("mul")(
+            match::used_once(),
+            match::either_arg(0, 1)(const_broadcast.bind("d"),
+                                    match::none_of(match::is_constant()).bind("z")));
+        return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = ins->inputs()[0];
+        auto b_ins = ins->inputs()[1];
+        auto d_ins = r.instructions["d"];
+        auto c_ins = r.instructions["c"];
+        auto z_ins = r.instructions["z"];
+        const auto& d_strides = d_ins->get_shape().strides();
+        // There should only be one stride that is not zero
+        if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
+            return;
+
+        if(not d_ins->get_shape().scalar())
+        {
+            if(d_strides.back() == 1 and not b_ins->can_eval())
+                return;
+            if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
+                return;
+        }
+
+        auto broadcast_v = d_ins->get_operator().to_value();
+        auto c_lens      = c_ins->get_shape().lens();
+        std::vector<int64_t> permutation(c_lens.size());
+        std::iota(permutation.begin(), permutation.end(), 0);
+        std::swap(permutation.back(), permutation[permutation.size() - 2]);
+        c_lens                  = reorder_dims(c_lens, permutation);
+        broadcast_v["out_lens"] = c_lens;
+        auto db_ins =
+            m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
+        auto db_transpose_ins =
+            m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
+        auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
+        if(c_ins == b_ins)
+        {
+            a_ins = z_ins;
+            b_ins = cd_ins;
+        }
+        else
+        {
+            a_ins = cd_ins;
+            b_ins = z_ins;
+        }
+        m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
+    }
+};
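Both new matchers rest on the same identity: an elementwise scale that varies along only one axis (the reason for the single-nonzero-stride check) commutes with matrix multiplication.

```latex
% column scaling: c varies along the last axis
\bigl((AB)\odot c\bigr)_{ij} = \Bigl(\sum_k A_{ik}B_{kj}\Bigr)c_j
                             = \bigl(A\,(B\operatorname{diag}(c))\bigr)_{ij}
% row scaling: c varies along the next-to-last axis
\bigl(c\odot(AB)\bigr)_{ij}  = c_i\sum_k A_{ik}B_{kj}
                             = \bigl((\operatorname{diag}(c)\,A)\,B\bigr)_{ij}
```

find_mul_dot uses it left to right, folding the scale into whichever dot operand is constant; find_dot_mul uses it right to left, hoisting a broadcast scale out of a dot input and multiplying it into the constant on the other side, so constant propagation can evaluate the scaled constant.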
 // ******************************
 // a * (x + b) => a * x + a * b
 // ******************************
@@ -361,30 +486,118 @@ struct find_inner_broadcast
 {
     auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
 
+    static auto non_scalar_op(const std::string& name)
+    {
+        return [=](instruction_ref ins) {
+            if(ins->get_shape().scalar())
+                return false;
+            return ins->name() == name;
+        };
+    }
+
     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins        = r.result;
         auto broadcasts = ins->inputs();
         if(broadcasts.empty())
             return;
+
+        bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
+                                any_of(broadcasts, non_scalar_op("multibroadcast"));
+        // If the broadcast is not a single dimension, then dont perform inner_broadcast
+        if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
+               if(i->get_shape().scalar())
+                   return false;
+               if(i->name() == "multibroadcast")
+                   return false;
+               auto input       = i->inputs().at(0);
+               const auto& lens = input->get_shape().lens();
+               return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
+                          return d == 1;
+                      }) < (lens.size() - 1);
+           }))
+            return;
+
         std::vector<instruction_ref> inputs;
         std::transform(broadcasts.begin(),
                        broadcasts.end(),
                        std::back_inserter(inputs),
-                       [](auto i) { return i->inputs().front(); });
-        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
-               return i->get_shape() != inputs.front()->get_shape() and
-                      i->get_shape().elements() != 1;
-           }))
-            return;
-        auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
-            return not i->get_shape().scalar();
-        });
-        if(b_it == broadcasts.end())
-            b_it = broadcasts.begin();
+                       [&](instruction_ref i) {
+                           auto input = i->inputs().front();
+                           if(mixed_broadcasts and not i->get_shape().scalar() and
+                              i->get_shape().lens().size() > 1)
+                               return m.insert_instruction(i, make_op("squeeze"), input);
+                           return input;
+                       });
+
+        std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
+                      if(i->get_shape().scalar())
+                          return 2;
+                      else if(i->name() == "broadcast")
+                          return 0;
+                      if(i->name() == "multibroadcast")
+                          return 1;
+                      return 3;
+                  }));
         auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
-        m.replace_instruction(ins, (*b_it)->get_operator(), op);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
+
+struct find_dot_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto a   = ins->inputs()[0];
+        auto b   = ins->inputs()[1];
+        if(a->get_operator().name() != b->get_operator().name())
+            return;
+        if(ins->get_shape().lens().size() < 3)
+            return;
+        auto nbatch_axes      = ins->get_shape().lens().size() - 2;
+        const auto& a_strides = a->get_shape().strides();
+        const auto& b_strides = b->get_shape().strides();
+        // Find leading batch axes that are broadcasted
+        auto p =
+            std::mismatch(a_strides.begin(),
+                          a_strides.begin() + nbatch_axes,
+                          b_strides.begin(),
+                          b_strides.begin() + nbatch_axes,
+                          [](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
+        auto naxes = p.first - a_strides.begin();
+        assert(naxes <= nbatch_axes);
+        std::vector<std::size_t> axes(naxes);
+        std::iota(axes.begin(), axes.end(), 0);
+        auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
+            auto input = b_ins->inputs()[0];
+            std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
+                                          b_ins->get_shape().lens().end());
+            if(b_ins->name() == "multibroadcast")
+            {
+                return m.insert_instruction(
+                    ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
+            }
+            else if(b_ins->name() == "broadcast")
+            {
+                auto v    = b_ins->get_operator().to_value();
+                auto axis = v.at("axis").to<std::size_t>() - naxes;
+                return m.insert_instruction(
+                    ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
+            }
+            assert(false);
+            return m.end();
+        };
+        auto a1        = insert_broadcast(a);
+        auto b1        = insert_broadcast(b);
+        auto dot       = m.insert_instruction(ins, make_op("dot"), a1, b1);
+        auto broadcast = m.insert_instruction(
+            ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
+        m.replace_instruction(ins, broadcast);
+    }
+};
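find_dot_broadcast applies a similar economy to batched dot: when both operands are broadcast along the same leading batch axes, every batch computes an identical product, so the pass strips those axes, performs the smaller dot once, and multibroadcasts the result back to the original output shape.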
@@ -393,7 +606,8 @@ struct find_concat_op
     auto matcher() const
     {
         return match::name("concat")(match::any_of[match::inputs()](
-            match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
+            match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
+            match::used_once()));
     }
 
     template <class Iterator>
@@ -412,7 +626,8 @@ struct find_concat_op
     static bool is_valid_op(const operation& op)
     {
-        return op.name() == "broadcast" or op.attributes().contains("pointwise");
+        return contains({"broadcast", "multibroadcast"}, op.name()) or
+               op.attributes().contains("pointwise");
     }
 
     void apply(module& m, const match::matcher_result& r) const
@@ -440,6 +655,16 @@ struct find_concat_op
             op    = b;
             iaxis = 0;
         }
+        else if(op.name() == "multibroadcast")
+        {
+            shape bshape = (*start)->get_shape();
+            auto input   = (*start)->inputs()[0];
+            if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
+                return {start, last};
+            op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
+            auto delta = bshape.lens().size() - input->get_shape().lens().size();
+            iaxis -= delta;
+        }
 
         std::vector<instruction_ref> concats;
         for(std::size_t i = 0; i < x->inputs().size(); i++)
@@ -1260,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
     {
         match::find_matches(m,
                             find_inner_broadcast{},
+                            find_dot_broadcast{},
                             find_double_add_lit_broadcast{},
                             find_add_lit_broadcast{},
                             find_add_convs{},
                             find_conv_dot_horiz_fusion{},
                             find_mul_conv{},
                             find_mul_slice_conv{},
+                            find_mul_dot{},
+                            find_dot_mul{},
                             find_mul_add{},
                             find_unit_ops{},
                             find_neg_unit_ops{},
...