"src/op/operator.h" did not exist on "73a6cb8bfd1f6dfc6197b7ad9253719dd720d681"
Commit 4d50c338 authored by Paul's avatar Paul
Browse files

Fix used_once fusions

parent 60f852b3
...@@ -94,6 +94,7 @@ struct module_pm : module_pass_manager ...@@ -94,6 +94,7 @@ struct module_pm : module_pass_manager
virtual module* get_common_parent() override { return common_parent; } virtual module* get_common_parent() override { return common_parent; }
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
trace("Pass: ", p.name());
assert(mod); assert(mod);
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
if(enabled(MIGRAPHX_TIME_PASSES{})) if(enabled(MIGRAPHX_TIME_PASSES{}))
......
...@@ -270,7 +270,7 @@ struct find_dot_mul ...@@ -270,7 +270,7 @@ struct find_dot_mul
auto matcher() const auto matcher() const
{ {
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant()); auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(match::either_arg(0, 1)( auto mul = match::name("mul")(match::used_once(), match::either_arg(0, 1)(
const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z"))); const_broadcast.bind("d"), match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c"))); return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
} }
......
...@@ -770,7 +770,7 @@ struct find_contiguous_pointwise ...@@ -770,7 +770,7 @@ struct find_contiguous_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise"))); return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")(match::used_once())));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -790,7 +790,7 @@ struct find_layernorm_pointwise ...@@ -790,7 +790,7 @@ struct find_layernorm_pointwise
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")(match::arg(0)( return precompile_name("pointwise")(match::arg(0)(
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm"))); precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm")(match::used_once()).bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -813,7 +813,7 @@ struct find_concat_pointwise ...@@ -813,7 +813,7 @@ struct find_concat_pointwise
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")( return precompile_name("pointwise")(
match::arg(0)(precompile_name("concat").bind("concat"))); match::arg(0)(precompile_name("concat")(match::used_once()).bind("concat")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment