utils.py 1.51 KB
Newer Older
1
import torch
2

3
import dgl
4
5
from dgl.data import CiteseerGraphDataset, CoraGraphDataset

6
7

def load_data(args):
    """Load a citation graph dataset and unpack the pieces a model needs.

    Args:
        args: Namespace-like object; ``args.dataset`` selects the dataset
            ("cora" or "citeseer") and ``args.gpu`` is a device index, where
            a negative value means stay on CPU.

    Returns:
        Tuple ``(g, features, labels, train_mask, test_mask, num_classes,
        cuda)``. ``g`` is the first graph of the dataset with self-loops
        added via ``dgl.add_self_loop``. ``features``, ``labels`` and the
        masks are read from ``g.ndata`` before self-loops are added. ``cuda``
        is True when ``args.gpu >= 0``.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
    }
    if args.dataset not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = dataset_classes[args.dataset]()

    graph = data[0]
    cuda = args.gpu >= 0
    if cuda:
        # Only the GPU path converts the graph with ``.int()`` before moving
        # it to device ``args.gpu``; the CPU graph is left as loaded.
        graph = graph.int().to(args.gpu)

    node_data = graph.ndata
    features, labels = node_data["feat"], node_data["label"]
    train_mask, test_mask = node_data["train_mask"], node_data["test_mask"]

    graph = dgl.add_self_loop(graph)
    return graph, features, labels, train_mask, test_mask, data.num_classes, cuda

27

28
def svd_feature(features, d=200):
    """Reduce node features to at most ``d`` dimensions via truncated SVD.

    The feature matrix is projected onto its top-``d`` left singular vectors,
    each scaled by its singular value (``U[:, :d] * S[:d]``). This keeps the
    dimensionality small to avoid the curse of dimensionality.

    Args:
        features: 2-D feature tensor of shape ``(num_nodes, num_feats)``.
        d: Target number of dimensions (default 200).

    Returns:
        ``features`` itself when it already has ``d`` or fewer columns.
        Otherwise a tensor of shape ``(num_nodes, min(d, k))``, where
        ``k = min(num_nodes, num_feats)`` is the rank bound of the reduced SVD.
    """
    if features.shape[1] <= d:
        return features
    # Reduced SVD (``some=True`` is the default): U is (n, k), S is (k,).
    # NOTE: torch.svd is deprecated in favour of
    # torch.linalg.svd(full_matrices=False); it is kept here for
    # compatibility with older PyTorch releases.
    U, S, _ = torch.svd(features)
    # Broadcasting scales column j of U by S[j]. This equals
    # U[:, :d] @ diag(S[:d]) without materialising a d x d matrix or paying
    # for an O(n * d^2) matmul.
    return U[:, :d] * S[:d]

36

37
def process_classids(labels_temp):
    """Renumber class ids to a contiguous range after unseen classes are removed.

    Each distinct label value is replaced by its rank among the sorted
    distinct values. The classes still present become ``0 .. num_present - 1``
    and keep their relative order. For example, ``[5, 2, 5, 9]`` becomes
    ``[1, 0, 1, 2]``.

    Args:
        labels_temp: Label tensor containing only the seen classes. It is
            modified in place. Integer-valued class ids are assumed.

    Returns:
        ``labels_temp`` itself, now holding the renumbered class ids.
    """
    # ``torch.unique`` sorts the distinct values. ``return_inverse`` gives,
    # for every element, the index of its value in that sorted list, which is
    # exactly the new class id. This is one vectorized call instead of a
    # per-element Python loop (each ``int(tensor)`` syncs when on GPU).
    _, new_ids = torch.unique(labels_temp, sorted=True, return_inverse=True)
    # Write back in place, cast to the original dtype, so callers holding a
    # reference to ``labels_temp`` see the renumbered values as before.
    labels_temp.copy_(new_ids)
    return labels_temp