# utils.py
import dgl
# Committed by Hongzhi (Steve), Chen.
import torch
3
4
from dgl.data import CiteseerGraphDataset, CoraGraphDataset

5
6

def load_data(args):
    """Load the requested graph dataset and unpack what training needs.

    Args:
        args: Object exposing ``dataset`` (``"cora"`` or ``"citeseer"``) and
            ``gpu`` (device passed to ``g.to``; a negative value leaves the
            graph where the dataset loader created it).

    Returns:
        A tuple ``(g, features, labels, train_mask, test_mask, num_classes,
        cuda)``. ``g`` is the first graph of the dataset with self-loops
        added; ``features``, ``labels``, ``train_mask`` and ``test_mask`` are
        read from ``g.ndata`` under the keys ``"feat"``, ``"label"``,
        ``"train_mask"`` and ``"test_mask"``; ``num_classes`` is
        ``data.num_classes``; ``cuda`` is True when ``args.gpu`` is
        non-negative.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]

    # A non-negative ``args.gpu`` selects a device; the graph is converted
    # with ``g.int()`` and moved there before any node data is read, so the
    # tensors returned below come from the moved graph.
    cuda = args.gpu >= 0
    if cuda:
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]

    g = dgl.add_self_loop(g)
    return g, features, labels, train_mask, test_mask, data.num_classes, cuda

26

27
def svd_feature(features, d=200):
    """Reduce node features to at most ``d`` dimensions via truncated SVD.

    Lower-dimensional features are used to avoid the curse of
    dimensionality.

    Args:
        features: 2-D tensor whose second dimension is the feature dimension.
        d: Maximum number of feature dimensions to keep (default 200).

    Returns:
        ``features`` itself, unchanged, when it already has ``d`` or fewer
        columns. Otherwise the product of the first ``d`` left singular
        vectors and the first ``d`` singular values, i.e.
        ``U[:, :d] @ diag(S[:d])``.
    """
    if features.shape[1] <= d:
        return features

    # ``torch.svd`` is deprecated; ``torch.linalg.svd`` with
    # ``full_matrices=False`` computes the same reduced decomposition.
    U, S, _ = torch.linalg.svd(features, full_matrices=False)

    # Scaling each column of U by its singular value equals U @ diag(S)
    # without materialising the dense diagonal matrix.
    return U[:, :d] * S[:d]

35

36
def process_classids(labels_temp):
    """Reorder the remaining classes with unseen classes removed.

    The distinct values in ``labels_temp`` are sorted in ascending order and
    each value is replaced by its rank, so the resulting class ids are the
    consecutive integers ``0 .. k-1``.

    Args:
        labels_temp: Label tensor from which the unseen classes have already
            been removed. It is modified in place.

    Returns:
        The same tensor object ``labels_temp``, now holding the reordered
        class ids.
    """
    if labels_temp.numel() == 0:
        return labels_temp

    # ``return_inverse`` gives, for every element, the index of its value in
    # the sorted list of unique values -- exactly the new class id.
    _, new_ids = torch.unique(labels_temp, sorted=True, return_inverse=True)

    # Write back in place so callers holding a reference see the remapping,
    # as they did with the original element-wise assignment.
    labels_temp.copy_(new_ids)
    return labels_temp