Commit 53a8997f authored by Paul's avatar Paul
Browse files

Fix tidy errors

parent f9531026
......@@ -92,7 +92,6 @@ MIGRAPHX_REGISTER_OP(fused_concat);
namespace {
static unsigned int counter = 0;
struct find_pointwise_concat_pointwise
{
auto matcher() const
......@@ -124,6 +123,7 @@ struct find_pointwise_concat_pointwise
[&](auto input) { return input != concat_ins; });
std::vector<module_ref> module_inputs;
static unsigned int counter = 0;
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
......
......@@ -59,12 +59,9 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>;
}
template <index_int Axis, class... InputPacks>
__device__ auto concat(InputPacks... input_packs)
template <index_int Axis, class Start, class InputPack, class F, class... Ts>
__device__ auto concat_each(index idx, Start start, InputPack input_pack, F f, Ts... ts)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input_pack) {
return input_pack([&](auto g, auto x, auto... xs) {
return concat_slices<Axis>(x, start, ts...)([&](auto z, auto... ys) {
idx.global_stride(x.get_shape().elements(),
......@@ -73,6 +70,15 @@ __device__ auto concat(InputPacks... input_packs)
return start + concat_ends<Axis>(x);
});
});
}
template <index_int Axis, class... InputPacks>
__device__ auto concat(InputPacks... input_packs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input_pack) {
return concat_each<Axis>(idx, start, input_pack, f, ts...);
})(_c<0>, input_packs...);
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment