// degree.cpp
#include "degree.h"

#include <torch/torch.h>

/// Counts how often each id in [0, num_nodes) occurs in `index`.
///
/// @param index       Ids to count. Assumed (not checked here) to be a 1-D
///                    int64 tensor whose values lie in [0, num_nodes), as
///                    scatter_add_ along dim 0 requires — TODO confirm at
///                    call sites.
/// @param num_nodes   Length of the returned tensor.
/// @param scalar_type dtype of the returned counts.
/// @return A tensor of shape {num_nodes} and dtype `scalar_type`, allocated
///         with `index`'s tensor options (so on the same device), whose
///         entry k is the number of occurrences of k in `index`.
///
/// Not marked `inline`: this definition lives in a .cpp next to degree.h,
/// and (presumably) callers in other translation units link against it.
at::Tensor degree(at::Tensor index, int num_nodes,
                  at::ScalarType scalar_type) {
  // TensorOptions replaces the deprecated `index.type().toScalarType(...)`
  // pattern: keep index's options (device etc.), override only the dtype.
  const auto options = index.options().dtype(scalar_type);
  auto out = at::zeros({num_nodes}, options);
  // One unit per entry of `index`; scatter_add_ sums each into out[index[i]].
  auto one = at::ones({index.size(0)}, options);
  return out.scatter_add_(0, index, one);
}