"docs/vscode:/vscode.git/clone" did not exist on "6f1d014463496e1f1066e6fa2ffbcf5a537d489c"
Commit 7fa93728 authored by Paul's avatar Paul
Browse files

Add concat kernel for input fusions

parent 3c160a3f
...@@ -74,5 +74,22 @@ __device__ auto concat(Inputs... inputs) ...@@ -74,5 +74,22 @@ __device__ auto concat(Inputs... inputs)
}; };
} }
// Concat over "input packs", used when the concat inputs are themselves fused.
//
// Returns a callable taking a function `f` and trailing arguments `ts...`.
// When invoked, it walks `input_packs...` from left to right with `fold`,
// seeding the running offset `start` with `_c<0>`. Each pack is invoked with a
// callback that receives `g`, a leading input `x` and further inputs `xs...`.
// For that pack, `concat_slices<Axis>(x, start, ts...)` is invoked with a
// callback receiving `z` and `ys...`, and for every index `i` handed out by
// `idx.global_stride` over `x.get_shape().elements()` it stores
//
//     z[i] = f(g(x[i], xs[i]...), ys[i]...)
//
// The innermost callback yields `start + concat_ends<Axis>(x)`, which becomes
// the `start` of the next pack.
//
// NOTE(review): presumably `z` and `ys...` are the slices of `ts...` that line
// up with `x` along `Axis` - confirm against `concat_slices`.
// NOTE(review): propagating the offset assumes that both `input_pack` and the
// callable returned by `concat_slices` forward their callback's return value -
// confirm against their definitions.
template <index_int Axis, class... InputPacks>
__device__ auto concat2(InputPacks... input_packs)
{
    return [=](auto f, auto... ts) {
        auto idx = make_index();
        fold([&](auto start, auto input_pack) {
            return input_pack([&](auto g, auto x, auto... xs) {
                // The result must be returned here: without it this callback
                // yields void and the next offset computed below never reaches
                // `fold`, so later packs cannot be placed after this one.
                return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
                    idx.global_stride(x.get_shape().elements(),
                                      [&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
                    // Offset at which the next pack's slice begins.
                    return start + concat_ends<Axis>(x);
                });
            });
        })(_c<0>, input_packs...);
    };
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP #endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment