common.cuh 317 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
#pragma once

#include <ATen/ATen.h>

// Threads per block used for 1-D kernel launches.
#define THREADS 1024
// Number of blocks needed so that BLOCKS(N) * THREADS >= N (ceiling
// division of N by THREADS). Both the argument and the whole expansion are
// parenthesized so the macro composes correctly inside larger expressions,
// e.g. `x / BLOCKS(n)`, `2 * BLOCKS(n)` or `BLOCKS(c ? a : b)`; the previous
// unparenthesized form silently produced wrong grid sizes in those cases.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Histogram of node ids: returns a tensor `counts` of length `num_nodes`
// where counts[v] is the number of positions i with index[i] == v.
// The result has the same tensor type (scalar type and backend) as `index`.
//
// Presumably used to compute node degrees from one row of an edge index —
// confirm at call sites.
//
// Assumes `index` is a 1-D integer (int64) tensor whose values all lie in
// [0, num_nodes); scatter_add_ along dim 0 into a 1-D tensor needs this.
// TODO(review): confirm callers guarantee it.
inline at::Tensor degree(at::Tensor index, int num_nodes) {
  const auto &index_type = index.type();

  // One unit of weight per entry of `index`, added into its target slot.
  auto ones = at::ones(index_type, {index.size(0)});
  auto counts = at::zeros(index_type, {num_nodes});

  counts.scatter_add_(0, index, ones);
  return counts;
}