Commit fd801fee authored by Paul's avatar Paul
Browse files

Format

parent 449df648
...@@ -935,10 +935,8 @@ auto pointwise_name(const std::string& s) ...@@ -935,10 +935,8 @@ auto pointwise_name(const std::string& s)
{ {
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) { return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front(); module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
return i.name() == s; if(n != 1)
});
if (n != 1)
return false; return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) { return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s; return starts_with(i.name(), "@") or i.name() == s;
...@@ -1028,7 +1026,11 @@ void fuse_ops::apply(module& m) const ...@@ -1028,7 +1026,11 @@ void fuse_ops::apply(module& m) const
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{}); find_add_clip{});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_gemm_pointwise{}, find_commutative_broadcast{}); match::find_matches(m,
find_triadd_layernorm{},
find_gemm_add{},
find_gemm_pointwise{},
find_commutative_broadcast{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment