main.py
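"""Train GRAND (Graph Random Neural Networks) on a citation dataset.

Loads Cora, Citeseer, or Pubmed via DGL, trains the GRAND model from model.py
with DropNode-augmented propagation and consistency regularization, early-stops
on validation performance, and reports test accuracy.
"""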
import argparse
import warnings

import dgl

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from model import GRAND

warnings.filterwarnings("ignore")


def argument():
    parser = argparse.ArgumentParser(description="GRAND")

    # data source params
    parser.add_argument(
        "--dataname", type=str, default="cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument(
        "--epochs", type=int, default=200, help="Training epochs."
    )
    parser.add_argument(
        "--early_stopping",
        type=int,
        default=200,
        help="Patience (in epochs) to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=5e-4,
        help="Weight decay (L2 regularization).",
    )
    # model params
    parser.add_argument(
        "--hid_dim", type=int, default=32, help="Hidden layer dimensionality."
    )
    parser.add_argument(
        "--dropnode_rate",
        type=float,
        default=0.5,
        help="DropNode rate (1 - keep probability).",
    )
    parser.add_argument(
        "--input_droprate",
        type=float,
        default=0.0,
        help="Dropout rate of the input layer.",
    )
    parser.add_argument(
        "--hidden_droprate",
        type=float,
        default=0.0,
        help="Dropout rate of the hidden layer.",
    )
    parser.add_argument(
        "--order", type=int, default=8, help="Number of propagation steps."
    )
    parser.add_argument(
        "--sample", type=int, default=4, help="Number of DropNode samplings."
    )
    parser.add_argument(
        "--tem", type=float, default=0.5, help="Sharpening temperature."
    )
    parser.add_argument(
        "--lam",
        type=float,
        default=1.0,
        help="Coefficient of the consistency regularization loss.",
    )
    parser.add_argument(
        "--use_bn",
        action="store_true",
        default=False,
        help="Use batch normalization.",
    )

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args


def consis_loss(logps, temp, lam):
    ps = [th.exp(p) for p in logps]
    ps = th.stack(ps, dim=2)

    avg_p = th.mean(ps, dim=2)
    sharp_p = (
        th.pow(avg_p, 1.0 / temp)
        / th.sum(th.pow(avg_p, 1.0 / temp), dim=1, keepdim=True)
    ).detach()

    sharp_p = sharp_p.unsqueeze(2)
    loss = th.mean(th.sum(th.pow(ps - sharp_p, 2), dim=1, keepdim=True))

    loss = lam * loss
    return loss

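# Note on consis_loss above: with S sampled predictions, the S probability
# matrices are stacked, their mean avg_p is sharpened as avg_p ** (1 / temp)
# (renormalized over classes and detached from the graph), and the loss is
# lam times the mean squared distance between each sample's probabilities and
# that sharpened target.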

if __name__ == "__main__":
    # Step 1: Prepare graph data and retrieve train/validation/test indices =========================== #
    # Load from DGL dataset
    args = argument()
    print(args)

    if args.dataname == "cora":
        dataset = CoraGraphDataset()
    elif args.dataname == "citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataname == "pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataname))

    graph = dataset[0]

    graph = dgl.add_self_loop(graph)
    device = args.device

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    # Step 2: Create model =================================================================== #
    model = GRAND(
        n_features,
        args.hid_dim,
        n_classes,
        args.sample,
        args.order,
        args.dropnode_rate,
        args.input_droprate,
        args.hidden_droprate,
        args.use_bn,
    )

    model = model.to(args.device)
    graph = graph.to(args.device)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.NLLLoss()
    opt = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    loss_best = np.inf
    acc_best = 0

    # Step 4: Training epochs ================================================================ #
    for epoch in range(args.epochs):
        """Training"""
        model.train()

        loss_sup = 0
        # Forward in training mode: the model returns a list of `args.sample`
        # DropNode-augmented log-probability predictions.
        logits = model(graph, feats, True)

        # calculate supervised loss
        for k in range(args.sample):
            loss_sup += F.nll_loss(logits[k][train_idx], labels[train_idx])

        loss_sup = loss_sup / args.sample

        # calculate consistency loss
        loss_consis = consis_loss(logits, args.tem, args.lam)

        loss_train = loss_sup + loss_consis
        acc_train = th.sum(
            logits[0][train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
        loss_train.backward()
        opt.step()

        """ Validating """
        model.eval()
        with th.no_grad():
            val_logits = model(graph, feats, False)

            loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])
            acc_val = th.sum(
                val_logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

            # Print out performance
            print(
                "In epoch {}, Train Acc: {:.4f} | Train Loss: {:.4f} ,Val Acc: {:.4f} | Val Loss: {:.4f}".format(
                    epoch,
                    acc_train,
                    loss_train.item(),
                    acc_val,
                    loss_val.item(),
                )
            )

            # set early stopping counter
            if loss_val < loss_best or acc_val > acc_best:
                if loss_val < loss_best:
                    best_epoch = epoch
                    th.save(model.state_dict(), args.dataname + ".pkl")
                no_improvement = 0
                loss_best = min(loss_val, loss_best)
                acc_best = max(acc_val, acc_best)
            else:
                no_improvement += 1
                if no_improvement == args.early_stopping:
                    print("Early stopping.")
                    break

    print("Optimization Finished!")

    print("Loading {}th epoch".format(best_epoch))
    model.load_state_dict(th.load(args.dataname + ".pkl"))

    """ Testing """
    model.eval()

    with th.no_grad():
        test_logits = model(graph, feats, False)
    test_acc = th.sum(
        test_logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)

    print("Test Acc: {:.4f}".format(test_acc))