"server/vscode:/vscode.git/clone" did not exist on "2ad895a6cc530474cae7e24ace1e463018172d0e"
train.py 4.23 KB
Newer Older
Aymen Waheb's avatar
Aymen Waheb committed
1
2
3
4
5
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
import dgl
from appnp import APPNP
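# `appnp` is the companion module in this example directory; it is assumed to
# define an APPNP model whose constructor matches the call in main() below:
# APPNP(graph, in_feats, hidden_sizes, n_classes, activation,
#       feat_dropout, edge_dropout, alpha, k).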


def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    if args.dataset == 'cora':
        data = CoraGraphDataset()
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset()
    elif args.dataset == 'pubmed':
        data = PubmedGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        g = g.to(args.gpu)

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
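    # the masks are boolean per-node tensors selecting the dataset's standard
    # train/validation/test split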
    in_feats = features.shape[1]
    n_classes = data.num_classes
    n_edges = g.number_of_edges()
    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    n_edges = g.number_of_edges()
    # add self loop
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
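    # removing existing self loops first guarantees exactly one self loop per
    # node, so the normalized adjacency used during propagation is well defined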

    # create APPNP model
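    # APPNP ("Predict Then Propagate", Klicpera et al., ICLR 2019) first maps
    # node features through an MLP to get H^(0), then runs k steps of
    # personalized-PageRank propagation:
    #   H^(t+1) = (1 - alpha) * A_hat @ H^(t) + alpha * H^(0)
    # where A_hat is the symmetrically normalized adjacency (with self loops)
    # and alpha is the teleport probability.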
    model = APPNP(g,
                  in_feats,
                  args.hidden_sizes,
                  n_classes,
                  F.relu,
                  args.in_drop,
                  args.edge_drop,
                  args.alpha,
                  args.k)

    if cuda:
        # move the model to the same device as the graph and features
        model.cuda(args.gpu)
    loss_fcn = nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        # guard against an empty duration list during the warm-up epochs
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, loss.item(),
                                            acc, n_edges / mean_dur / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='APPNP')
    register_data_args(parser)
    parser.add_argument("--in-drop", type=float, default=0.5,
                        help="input feature dropout")
    parser.add_argument("--edge-drop", type=float, default=0.5,
                        help="edge propagation dropout")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--hidden_sizes", type=int, nargs='+', default=[64],
                        help="hidden unit sizes for appnp")
    parser.add_argument("--k", type=int, default=10,
                        help="Number of propagation steps")
    parser.add_argument("--alpha", type=float, default=0.1,
                        help="Teleport Probability")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    args = parser.parse_args()
    print(args)

    main(args)
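
# Example invocations (the --dataset flag is added by register_data_args;
# assumes a DGL install with the citation datasets and the local appnp.py):
#   python train.py --dataset cora
#   python train.py --dataset citeseer --gpu 0 --n-epochs 300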