Commit 53a8997f authored by Paul's avatar Paul
Browse files

Fix tidy errors

parent f9531026
...@@ -92,7 +92,6 @@ MIGRAPHX_REGISTER_OP(fused_concat); ...@@ -92,7 +92,6 @@ MIGRAPHX_REGISTER_OP(fused_concat);
namespace { namespace {
static unsigned int counter = 0;
struct find_pointwise_concat_pointwise struct find_pointwise_concat_pointwise
{ {
auto matcher() const auto matcher() const
...@@ -124,6 +123,7 @@ struct find_pointwise_concat_pointwise ...@@ -124,6 +123,7 @@ struct find_pointwise_concat_pointwise
[&](auto input) { return input != concat_ins; }); [&](auto input) { return input != concat_ins; });
std::vector<module_ref> module_inputs; std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(), std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(), concat_ins->inputs().end(),
std::back_inserter(module_inputs), std::back_inserter(module_inputs),
......
...@@ -59,20 +59,26 @@ constexpr auto concat_ends(Input) ...@@ -59,20 +59,26 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template <index_int Axis, class Start, class InputPack, class F, class... Ts>
__device__ auto concat_each(index idx, Start start, InputPack input_pack, F f, Ts... ts)
{
return input_pack([&](auto g, auto x, auto... xs) {
return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(),
[&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
return start + concat_ends<Axis>(x);
});
});
}
template <index_int Axis, class... InputPacks> template <index_int Axis, class... InputPacks>
__device__ auto concat(InputPacks... input_packs) __device__ auto concat(InputPacks... input_packs)
{ {
return [=](auto f, auto... ts) { return [=](auto f, auto... ts) {
auto idx = make_index(); auto idx = make_index();
fold([&](auto start, auto input_pack) { fold([&](auto start, auto input_pack) {
return input_pack([&](auto g, auto x, auto... xs) { return concat_each<Axis>(idx, start, input_pack, f, ts...);
return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(),
[&](auto i) { z[i] = f(g(x[i], xs[i]...), ys[i]...); });
return start + concat_ends<Axis>(x);
});
});
})(_c<0>, input_packs...); })(_c<0>, input_packs...);
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment