loop.cpp 324 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
#ifndef LOOP_CPP
#define LOOP_CPP

#include <torch/torch.h>

/// @brief Drops every position i at which row[i] equals col[i].
///
/// A single boolean mask (row != col, element-wise) is built once and applied
/// to both inputs, so the two returned tensors have the same length and their
/// entries stay paired position-by-position.
///
/// NOTE(review): presumably `row` / `col` are the source / target node indices
/// of a COO edge list, so this removes self-loop edges. TODO confirm against
/// callers.
///
/// @param row First index tensor (compared element-wise against `col`).
/// @param col Second index tensor.
/// @return (row, col) restricted to the positions where they differ. Results
///         come from `masked_select`, which yields 1-D tensors.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  // `ne` is the element-wise "not equal" comparison; true marks an entry to keep.
  const auto keep = row.ne(col);

  auto kept_row = row.masked_select(keep);
  auto kept_col = col.masked_select(keep);

  return std::make_tuple(kept_row, kept_col);
}

#endif // LOOP_CPP