main.py 7.04 KB
Newer Older
1
import argparse
2
3
import warnings

4
5
6
7
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
8
9
import torch.optim as optim
from model import GRAND
10
11

import dgl
12
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
13

14
warnings.filterwarnings("ignore")
15
16
17
18


def argument(argv=None):
    """Parse command-line arguments for GRAND training.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. ``None`` (the default) falls back to
        ``sys.argv[1:]``, preserving the original CLI behaviour.

    Returns
    -------
    argparse.Namespace
        Parsed arguments, extended with a ``device`` attribute:
        ``"cuda:<gpu>"`` when ``--gpu`` is a valid index and CUDA is
        available, otherwise ``"cpu"``.
    """
    parser = argparse.ArgumentParser(description="GRAND")

    # data source params
    parser.add_argument(
        "--dataname", type=str, default="cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=200, help="Training epochs."
    )
    parser.add_argument(
        "--early_stopping",
        type=int,
        default=200,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument(
        "--weight_decay", type=float, default=5e-4, help="L2 reg."
    )
    # model params
    parser.add_argument(
        "--hid_dim", type=int, default=32, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--dropnode_rate",
        type=float,
        default=0.5,
        help="Dropnode rate (1 - keep probability).",
    )
    parser.add_argument(
        "--input_droprate",
        type=float,
        default=0.0,
        help="dropout rate of input layer",
    )
    parser.add_argument(
        "--hidden_droprate",
        type=float,
        default=0.0,
        help="dropout rate of hidden layer",
    )
    parser.add_argument("--order", type=int, default=8, help="Propagation step")
    parser.add_argument(
        "--sample", type=int, default=4, help="Sampling times of dropnode"
    )
    parser.add_argument(
        "--tem", type=float, default=0.5, help="Sharpening temperature"
    )
    parser.add_argument(
        "--lam",
        type=float,
        default=1.0,
        help="Coefficient of consistency regularization",
    )
    parser.add_argument(
        "--use_bn",
        action="store_true",
        default=False,
        help="Using Batch Normalization",
    )

    args = parser.parse_args(argv)

    # check cuda: only use the requested GPU when CUDA is actually present
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args

95

96
97
def consis_loss(logps, temp, lam):
    ps = [th.exp(p) for p in logps]
98
99
100
101
102
103
104
    ps = th.stack(ps, dim=2)

    avg_p = th.mean(ps, dim=2)
    sharp_p = (
        th.pow(avg_p, 1.0 / temp)
        / th.sum(th.pow(avg_p, 1.0 / temp), dim=1, keepdim=True)
    ).detach()
105
106

    sharp_p = sharp_p.unsqueeze(2)
107
    loss = th.mean(th.sum(th.pow(ps - sharp_p, 2), dim=1, keepdim=True))
108
109
110
111

    loss = lam * loss
    return loss

112
113

if __name__ == "__main__":
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    args = argument()
    print(args)

    if args.dataname == "cora":
        dataset = CoraGraphDataset()
    elif args.dataname == "citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataname == "pubmed":
        dataset = PubmedGraphDataset()
    else:
        # Fail fast with a clear message instead of a NameError on `dataset`.
        raise ValueError("Unknown dataset: {}".format(args.dataname))

    graph = dataset[0]
    # Self-loops so every node keeps its own features during propagation.
    graph = dgl.add_self_loop(graph)
    device = args.device

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    # Convert boolean masks to index tensors on the target device.
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    # Step 2: Create model =================================================================== #
    model = GRAND(
        n_features,
        args.hid_dim,
        n_classes,
        args.sample,
        args.order,
        args.dropnode_rate,
        args.input_droprate,
        args.hidden_droprate,
        args.use_bn,
    )
    model = model.to(args.device)
    graph = graph.to(args.device)

    # Step 3: Create training components ===================================================== #
    opt = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    loss_best = np.inf
    acc_best = 0
    # Defensive initialization; both are overwritten on the first epoch
    # because loss_best starts at infinity.
    best_epoch = 0
    no_improvement = 0

    # Step 4: training epoches =============================================================== #
    for epoch in range(args.epochs):
        """ Training """
        model.train()

        loss_sup = 0
        # In training mode the model returns one set of logits per
        # dropnode sample (args.sample of them).
        logits = model(graph, feats, True)

        # calculate supervised loss (averaged over the S samples)
        for k in range(args.sample):
            loss_sup += F.nll_loss(logits[k][train_idx], labels[train_idx])
        loss_sup = loss_sup / args.sample

        # calculate consistency loss across the S sampled predictions
        loss_consis = consis_loss(logits, args.tem, args.lam)

        loss_train = loss_sup + loss_consis
        acc_train = th.sum(
            logits[0][train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
        loss_train.backward()
        opt.step()

        """ Validating """
        model.eval()
        with th.no_grad():
            val_logits = model(graph, feats, False)

            loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])
            acc_val = th.sum(
                val_logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

            # Print out performance
            print(
                "In epoch {}, Train Acc: {:.4f} | Train Loss: {:.4f} ,Val Acc: {:.4f} | Val Loss: {:.4f}".format(
                    epoch,
                    acc_train,
                    loss_train.item(),
                    acc_val,
                    loss_val.item(),
                )
            )

            # set early stopping counter
            if loss_val < loss_best or acc_val > acc_best:
                # Checkpoint only on a new best validation loss.
                if loss_val < loss_best:
                    best_epoch = epoch
                    th.save(model.state_dict(), args.dataname + ".pkl")
                no_improvement = 0
                loss_best = min(loss_val, loss_best)
                acc_best = max(acc_val, acc_best)
            else:
                no_improvement += 1
                if no_improvement == args.early_stopping:
                    print("Early stopping.")
                    break

    print("Optimization Finished!")

    # Restore the checkpoint with the best validation loss.
    print("Loading {}th epoch".format(best_epoch))
    model.load_state_dict(th.load(args.dataname + ".pkl"))

    """ Testing """
    model.eval()
    with th.no_grad():
        test_logits = model(graph, feats, False)
        test_acc = th.sum(
            test_logits[test_idx].argmax(dim=1) == labels[test_idx]
        ).item() / len(test_idx)

    print("Test Acc: {:.4f}".format(test_acc))