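"""Semi-supervised molecular property regression on QM9 with InfoGraphS.

A small labeled subset trains a supervised regressor, while the remaining
molecules drive InfoGraphS's unsupervised and consistency objectives.

Example invocation (all flags are defined in ``argument()`` below):

    python semisupervised.py --target mu --train_num 5000 --gpu 0
"""
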
import argparse

import dgl

import numpy as np
import torch as th
import torch.nn.functional as F
from dgl.data import QM9EdgeDataset
from dgl.data.utils import Subset
from dgl.dataloading import GraphDataLoader
from model import InfoGraphS


def argument():
    parser = argparse.ArgumentParser(description="InfoGraphS")

    # data source params
    parser.add_argument(
        "--target", type=str, default="mu", help="Choose regression task"
    )
    parser.add_argument(
        "--train_num", type=int, default=5000, help="Size of training set"
    )

    # training params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index; -1 means CPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="Training epochs."
    )
    parser.add_argument(
        "--batch_size", type=int, default=20, help="Training batch size."
    )
    parser.add_argument(
        "--val_batch_size", type=int, default=100, help="Validation batch size."
    )

    parser.add_argument(
        "--lr", type=float, default=0.001, help="Learning rate."
    )
    parser.add_argument("--wd", type=float, default=0, help="Weight decay.")

    # model params
    parser.add_argument(
        "--hid_dim", type=int, default=64, help="Hidden layer dimensionality"
    )
    parser.add_argument(
        "--reg", type=float, default=0.001, help="Regularization coefficient"
    )

    args = parser.parse_args()

    # check cuda
    if args.gpu != -1 and th.cuda.is_available():
        args.device = "cuda:{}".format(args.gpu)
    else:
        args.device = "cpu"

    return args


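# QM9EdgeDataset yields sparse molecular graphs; this subclass overrides
# __getitem__ to return fully connected graphs whose edge features are the
# original bond features (zero for atom pairs without a bond) plus the
# inter-atomic distance.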
class DenseQM9EdgeDataset(QM9EdgeDataset):
    def __getitem__(self, idx):
        r"""Get graph and label by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        dgl.DGLGraph
           The graph contains:

           - ``ndata['pos']``: the coordinates of each atom
           - ``ndata['attr']``: the features of each atom
           - ``edata['edge_attr']``: the features of each bond

        Tensor
            Property values of molecular graphs
        """

        pos = self.node_pos[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
        src = self.src[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
        dst = self.dst[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]

        g = dgl.graph((src, dst))

        g.ndata["pos"] = th.tensor(pos).float()
        g.ndata["attr"] = th.tensor(
            self.node_attr[self.n_cumsum[idx] : self.n_cumsum[idx + 1]]
        ).float()
        g.edata["edge_attr"] = th.tensor(
            self.edge_attr[self.ne_cumsum[idx] : self.ne_cumsum[idx + 1]]
        ).float()

        label = th.tensor(self.targets[idx][self.label_keys]).float()

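        # densify: enumerate every ordered (row, col) pair of atoms so that
        # the returned graph is fully connected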
        n_nodes = g.num_nodes()
        row = th.arange(n_nodes)
        col = th.arange(n_nodes)

        row = row.view(-1, 1).repeat(1, n_nodes).view(-1)
        col = col.repeat(n_nodes)

        src = g.edges()[0]
        dst = g.edges()[1]

        eid = src * n_nodes + dst  # linear index of each real edge in the dense layout
        size = list(g.edata["edge_attr"].size())
        size[0] = n_nodes * n_nodes
        edge_attr = g.edata["edge_attr"].new_zeros(size)

        edge_attr[eid] = g.edata["edge_attr"]

        pos = g.ndata["pos"]
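        # Euclidean distance between every ordered pair of atoms, appended
        # below as one extra edge feature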
        dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)

        new_edge_attr = th.cat([edge_attr, dist.type_as(edge_attr)], dim=-1)

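        # assemble the dense graph and drop self-loops, leaving one edge per
        # ordered pair of distinct atoms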
        graph = dgl.graph((row, col))
        graph.ndata["attr"] = g.ndata["attr"]
        graph.edata["edge_attr"] = new_edge_attr
        graph = graph.remove_self_loop()

        return graph, label


def collate(samples):
    """collate function for building graph dataloader"""

    # generate batched graphs and labels
    graphs, targets = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    batched_targets = th.Tensor(targets)

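    # tag each node with the index of the graph it belongs to, so the
    # unsupervised objective can match node representations to their graph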
    n_graphs = len(graphs)
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_targets


def evaluate(model, loader, num, device):
    """Return the mean absolute error over the ``num`` samples in ``loader``."""
    error = 0
    # no gradients are needed at evaluation time
    with th.no_grad():
        for graphs, targets in loader:
            graphs = graphs.to(device)
            nfeat, efeat = graphs.ndata["attr"], graphs.edata["edge_attr"]
            targets = targets.to(device)
            error += (model(graphs, nfeat, efeat) - targets).abs().sum().item()

    return error / num


if __name__ == "__main__":
    # Step 1: Prepare graph data   ===================================== #
    args = argument()
    label_keys = [args.target]
    print(args)

    dataset = DenseQM9EdgeDataset(label_keys=label_keys)

    # Train/Val/Test Splitting
    N = dataset.targets.shape[0]
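    # random split: 10k molecules for validation, 10k for test, and
    # args.train_num labeled molecules for supervised training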
    all_idx = np.arange(N)
    np.random.shuffle(all_idx)

    val_num = 10000
    test_num = 10000

    val_idx = all_idx[:val_num]
    test_idx = all_idx[val_num : val_num + test_num]
    train_idx = all_idx[
        val_num + test_num : val_num + test_num + args.train_num
    ]

    train_data = Subset(dataset, train_idx)
    val_data = Subset(dataset, val_idx)
    test_data = Subset(dataset, test_idx)

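    # everything outside val/test (the labeled training set included) is
    # used without labels by the unsupervised objective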
    unsup_idx = all_idx[val_num + test_num :]
    unsup_data = Subset(dataset, unsup_idx)

    # generate supervised training dataloader and unsupervised training dataloader
    train_loader = GraphDataLoader(
        train_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    unsup_loader = GraphDataLoader(
        unsup_data,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    # generate validation & testing dataloader
    val_loader = GraphDataLoader(
        val_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    test_loader = GraphDataLoader(
        test_data,
        batch_size=args.val_batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    print("======== target = {} ========".format(args.target))

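    # infer the node feature dimensionality from the first graph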
    in_dim = dataset[0][0].ndata["attr"].shape[1]

    # Step 2: Create model =================================================================== #
    model = InfoGraphS(in_dim, args.hid_dim)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
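    # shrink the learning rate when the validation error plateaus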
    scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.7, patience=5, min_lr=0.000001
    )

    # Step 4: training epochs =============================================================== #
    best_val_error = float("inf")
    test_error = float("inf")

    for epoch in range(args.epochs):
        # Training
        model.train()
        lr = scheduler.optimizer.param_groups[0]["lr"]

        # running sums of each loss term over the epoch
        sup_loss_all = 0
        unsup_loss_all = 0
        consis_loss_all = 0

        for sup_data, unsup_data in zip(train_loader, unsup_loader):
            sup_graph, sup_target = sup_data
            unsup_graph, _ = unsup_data

            sup_graph = sup_graph.to(args.device)
            unsup_graph = unsup_graph.to(args.device)

            sup_nfeat, sup_efeat = (
                sup_graph.ndata["attr"],
                sup_graph.edata["edge_attr"],
            )
            unsup_nfeat, unsup_efeat, unsup_graph_id = (
                unsup_graph.ndata["attr"],
                unsup_graph.edata["edge_attr"],
                unsup_graph.ndata["graph_id"],
            )

            sup_target = sup_target.to(args.device)

            optimizer.zero_grad()

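            # supervised MSE on labeled graphs, plus InfoGraphS's unsupervised
            # and consistency losses on unlabeled graphs; the consistency term
            # is weighted by --reg when the three are combined below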
            sup_loss = F.mse_loss(
                model(sup_graph, sup_nfeat, sup_efeat), sup_target
            )
            unsup_loss, consis_loss = model.unsup_forward(
                unsup_graph, unsup_nfeat, unsup_efeat, unsup_graph_id
            )

            loss = sup_loss + unsup_loss + args.reg * consis_loss

            loss.backward()

            sup_loss_all += sup_loss.item()
            unsup_loss_all += unsup_loss.item()
            consis_loss_all += consis_loss.item()

            optimizer.step()

        print(
            "Epoch: {}, Sup_Loss: {:4f}, Unsup_loss: {:.4f}, Consis_loss: {:.4f}".format(
                epoch, sup_loss_all, unsup_loss_all, consis_loss_all
            )
        )

        model.eval()

        val_error = evaluate(model, val_loader, val_num, args.device)
        scheduler.step(val_error)

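        # report the test error measured at the epoch with the best
        # validation error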
        if val_error < best_val_error:
            best_val_error = val_error
            test_error = evaluate(model, test_loader, test_num, args.device)

        print(
            "Epoch: {}, LR: {}, val_error: {:.4f}, best_test_error: {:.4f}".format(
                epoch, lr, val_error, test_error
            )
        )