// degree.cpp — header-only helper (include-guarded by DEGREE_INC).
#ifndef DEGREE_INC
#define DEGREE_INC
#include <cstdint>

#include <torch/torch.h>

/// @brief Count how many times each node id appears in `index`.
///
/// Creates a length-`num_nodes` tensor filled with 0 (dtype `scalar_type`,
/// same backend as `index`). It then scatter-adds a 1 for every entry of
/// `index` along dimension 0.
///
/// @param index       Node ids to count. NOTE(review): assumed to be a 1-D
///                    int64 (kLong) tensor with values in [0, num_nodes), as
///                    scatter_add_ requires — this is not validated here.
/// @param num_nodes   Length of the returned tensor.
/// @param scalar_type Dtype of the returned counts.
/// @return Tensor of shape {num_nodes}; entry i holds the number of times i
///         occurs among the first index.size(0) entries of `index`.
inline at::Tensor degree(const at::Tensor &index, int64_t num_nodes,
                         at::ScalarType scalar_type) {
  // NOTE(review): the Type-based at::full overloads are deprecated in newer
  // PyTorch releases. Migrate to at::zeros / at::ones with
  // index.options().dtype(scalar_type) when upgrading.

  // Same backend as `index`, but with the requested scalar type.
  auto zero = at::full(index.type().toScalarType(scalar_type), {num_nodes}, 0);
  // scatter_add_ needs `src` to share the destination's dtype, so build the
  // ones from zero.type() rather than index.type().
  auto one = at::full(zero.type(), {index.size(0)}, 1);
  return zero.scatter_add_(0, index, one);
}
#endif // DEGREE_INC