Commit 7fa93728 authored by Paul's avatar Paul
Browse files

Add concat kernel for input fusions

parent 3c160a3f
......@@ -74,5 +74,22 @@ __device__ auto concat(Inputs... inputs)
};
}
template <index_int Axis, class... InputPacks>
__device__ auto concat2(InputPacks... input_packs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input_pack) {
return input_pack([&](auto g, auto x, auto... xs) {
concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(),
[&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
return start + concat_ends<Axis>(x);
});
});
})(_c<0>, input_packs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment