"...git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "ee4fd16f2c02a643e70c5393f7bb27cfda58671f"
Commit eb186040 authored by Paul's avatar Paul
Browse files

Add concat pointwise fusions

parent 6570087f
...@@ -41,11 +41,15 @@ struct precompile_op ...@@ -41,11 +41,15 @@ struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
std::size_t additional_args = 1; std::size_t additional_args = 1;
bool ignore_modules = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.op, "op"), f(self.additional_args, "additional_args")); return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules")
);
} }
std::string name() const { return "gpu::precompile_op"; } std::string name() const { return "gpu::precompile_op"; }
...@@ -54,6 +58,8 @@ struct precompile_op ...@@ -54,6 +58,8 @@ struct precompile_op
{ {
// Pop off additional args // Pop off additional args
inputs.resize(inputs.size() - additional_args); inputs.resize(inputs.size() - additional_args);
if (ignore_modules)
return op.compute_shape(inputs);
return op.compute_shape(inputs, mods); return op.compute_shape(inputs, mods);
} }
......
...@@ -1186,6 +1186,34 @@ struct find_layernorm_pointwise ...@@ -1186,6 +1186,34 @@ struct find_layernorm_pointwise
} }
}; };
struct find_concat_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
precompile_name("concat").bind("concat")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat = r.instructions["concat"];
auto* pm = ins->module_inputs().front();
if(not concat->module_inputs().empty())
return;
auto inputs = concat->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
auto op = concat->get_operator();
op.from_value({{"additional_args", ins->inputs().size()}, {"ignore_modules", true}});
m.replace_instruction(ins, op, inputs, {pm});
}
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
...@@ -1209,6 +1237,7 @@ void fuse_ops::apply(module& m) const ...@@ -1209,6 +1237,7 @@ void fuse_ops::apply(module& m) const
find_triadd_layernorm{}, find_triadd_layernorm{},
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{}, find_layernorm_pointwise{},
find_concat_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{}, find_contiguous_tranpose_gemm{},
find_commutative_broadcast{}); find_commutative_broadcast{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment