# degree.py
import torch
from torch.autograd import Variable

from .new import new


def degree(index, num_nodes=None, out=None):
    """Count how often each node id occurs in ``index``.

    Args:
        index: Integer tensor (or ``Variable``) of node ids. It is scattered
            along dim 0 against a source of length ``index.size(0)``, so it
            is expected to be 1-D with entries in ``[0, num_nodes)``.
        num_nodes: Length of the result. Defaults to ``index.max() + 1``,
            or ``0`` when ``index`` has no elements.
        out: Optional tensor/``Variable`` to write into. It is resized to
            ``num_nodes`` and overwritten with zeros before counting.
            Defaults to a new float tensor created via ``index.new()``.

    Returns:
        ``out``, whose entry ``i`` is the number of entries of ``index``
        equal to ``i``.
    """
    if num_nodes is None:
        # ``index.max()`` raises on an empty tensor, so an empty ``index``
        # yields zero nodes. ``int`` ensures a plain Python size even when
        # ``max()`` returns a 0-dim tensor (newer PyTorch releases).
        num_nodes = 0 if index.numel() == 0 else int(index.max()) + 1
    out = index.new().float() if out is None else out
    # Keep ``index`` the same kind (tensor vs. Variable) as ``out`` so that
    # ``scatter_add_`` below receives matching argument types.
    index = index if torch.is_tensor(out) else Variable(index)

    # A Variable is resized through its underlying ``.data`` tensor.
    if torch.is_tensor(out):
        out.resize_(num_nodes)
    else:
        out.data.resize_(num_nodes)

    # One unit of weight per entry of ``index``. ``new`` is the project helper
    # imported from ``.new`` (assumed: allocates a tensor of the same type as
    # ``out`` with the given size; TODO confirm).
    one = new(out, index.size(0)).fill_(1)
    return out.fill_(0).scatter_add_(0, index, one)