train_lightning_unsupervised.py 8.59 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
4
5
6
import argparse
import glob
import os
import sys
import time

7
import dgl
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
8
9
import dgl.function as fn
import dgl.nn.pytorch as dglnn
10
11
12
13
14
15
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
16
from model import compute_acc_unsupervised as compute_acc, SAGE
17
18
19

from negative_sampler import NegativeSampler
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
20
21
22
23
24
25

from pytorch_lightning.callbacks import Callback, ModelCheckpoint

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from load_graph import inductive_split, load_ogb, load_reddit

26
27
28
29

class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy loss for unsupervised link prediction.

    Scores every edge of a positive graph and a negative graph with the
    dot product of its incident node embeddings, then treats positive
    edges as label 1 and negative edges as label 0.
    """

    def forward(self, block_outputs, pos_graph, neg_graph):
        """Compute the link-prediction loss.

        Parameters
        ----------
        block_outputs : Tensor
            Embeddings for the shared node set of both graphs.
        pos_graph : DGLGraph
            Graph whose edges are observed (positive) node pairs.
        neg_graph : DGLGraph
            Graph whose edges are sampled negative node pairs.

        Returns
        -------
        Tensor
            Scalar binary cross-entropy loss over all edges.
        """
        # local_scope() keeps the temporary "h"/"score" features from
        # leaking into the caller's graphs.
        with pos_graph.local_scope():
            pos_graph.ndata["h"] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
            pos_score = pos_graph.edata["score"]
        with neg_graph.local_scope():
            neg_graph.ndata["h"] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
            neg_score = neg_graph.edata["score"]

        score = th.cat([pos_score, neg_score])
        # ones_like/zeros_like already produce float tensors matching the
        # score dtype, so the previous long() -> float() round trip was
        # redundant and has been dropped.
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        return F.binary_cross_entropy_with_logits(score, label)

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
45

46
class SAGELightning(LightningModule):
    """LightningModule wrapping an unsupervised GraphSAGE model.

    Training minimizes a link-prediction loss (``CrossEntropyLoss``)
    over positive/negative edge samples; validation returns raw node
    embeddings that a callback collects for downstream evaluation.
    """

    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr
    ):
        super().__init__()
        self.save_hyperparameters()
        self.module = SAGE(
            in_feats, n_hidden, n_classes, n_layers, activation, dropout
        )
        self.lr = lr
        self.loss_fcn = CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        input_nodes, pos_graph, neg_graph, mfgs = batch
        # Use the module's own device (Lightning keeps it up to date)
        # rather than the script-level ``device`` global, so the module
        # also works outside this script's __main__ block.
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        pos_graph = pos_graph.to(self.device)
        neg_graph = neg_graph.to(self.device)
        batch_inputs = mfgs[0].srcdata["features"]
        batch_pred = self.module(mfgs, batch_inputs)
        loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)
        self.log(
            "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata["features"]
        # Return the raw embeddings; UnsupervisedClassification gathers
        # them across batches and fits a downstream classifier.
        return self.module(mfgs, batch_inputs)

    def configure_optimizers(self):
        return th.optim.Adam(self.parameters(), lr=self.lr)


class DataModule(LightningDataModule):
    """LightningDataModule producing DGL neighbor-sampled mini-batches.

    Loads either the reddit or ogbn-products graph, splits node ids by
    the dataset's train/val masks, and builds edge-prediction (train)
    and node (validation) dataloaders over a multi-layer neighbor
    sampler.

    Parameters
    ----------
    dataset_name : str
        Either "reddit" or "ogbn-products".
    data_cpu : bool
        If False, move the graph and node-id tensors to ``device``.
    fan_out : sequence of int, optional
        Neighbor fan-out per GNN layer; defaults to [10, 25].
    device : torch.device
        Target device when ``data_cpu`` is False.
    batch_size : int
        Mini-batch size for both dataloaders.
    num_workers : int
        Number of sampling worker processes.
    """

    def __init__(
        self,
        dataset_name,
        data_cpu=False,
        fan_out=None,
        device=th.device("cpu"),
        batch_size=1000,
        num_workers=4,
    ):
        super().__init__()
        # None sentinel avoids a mutable default argument; [10, 25]
        # matches the previous default.
        if fan_out is None:
            fan_out = [10, 25]
        if dataset_name == "reddit":
            g, n_classes = load_reddit()
            n_edges = g.num_edges()
            # In this edge list the second half mirrors the first half,
            # so edge i and edge i + n_edges/2 are reverses of each
            # other.
            reverse_eids = th.cat(
                [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]
            )
        elif dataset_name == "ogbn-products":
            g, n_classes = load_ogb("ogbn-products")
            n_edges = g.num_edges()
            # The reverse edge of edge 0 in OGB products dataset is 1.
            # The reverse edge of edge 2 is 3.  So on so forth.
            reverse_eids = th.arange(n_edges) ^ 1
        else:
            raise ValueError("unknown dataset")

        train_nid = th.nonzero(g.ndata["train_mask"], as_tuple=True)[0]
        val_nid = th.nonzero(g.ndata["val_mask"], as_tuple=True)[0]
        # Test nodes: everything that is neither train nor validation.
        test_nid = th.nonzero(
            ~(g.ndata["train_mask"] | g.ndata["val_mask"]), as_tuple=True
        )[0]

        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [int(_) for _ in fan_out]
        )

        dataloader_device = th.device("cpu")
        if not data_cpu:
            train_nid = train_nid.to(device)
            val_nid = val_nid.to(device)
            test_nid = test_nid.to(device)
            # Keep only the CSC format before copying to the target
            # device; neighbor sampling needs nothing else.
            g = g.formats(["csc"])
            g = g.to(device)
            dataloader_device = device

        self.g = g
        self.train_nid, self.val_nid, self.test_nid = (
            train_nid,
            val_nid,
            test_nid,
        )
        self.sampler = sampler
        self.device = dataloader_device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.in_feats = g.ndata["features"].shape[1]
        self.n_classes = n_classes
        self.reverse_eids = reverse_eids

    def train_dataloader(self):
        # NOTE(review): relies on the script-level ``args`` for the
        # negative-sampling configuration; consider passing num_negs /
        # neg_share through the constructor instead.
        sampler = dgl.dataloading.as_edge_prediction_sampler(
            self.sampler,
            # Exclude the reverse of each sampled edge so the model
            # cannot trivially predict an edge from its own reverse.
            exclude="reverse_id",
            reverse_eids=self.reverse_eids,
            negative_sampler=NegativeSampler(
                self.g, args.num_negs, args.neg_share
            ),
        )
        return dgl.dataloading.DataLoader(
            self.g,
            np.arange(self.g.num_edges()),
            sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        # Note that the validation data loader is a DataLoader
        # as we want to evaluate all the node embeddings.
        return dgl.dataloading.DataLoader(
            self.g,
            np.arange(self.g.num_nodes()),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )
177
178
179
180
181
182


class UnsupervisedClassification(Callback):
    """Evaluate the unsupervised embeddings after each validation pass.

    Collects the per-batch embeddings returned by ``validation_step``
    and, once the epoch ends, scores them with ``compute_acc`` and logs
    the micro F1.
    """

    def on_validation_epoch_start(self, trainer, pl_module):
        # Fresh buffer every epoch; filled by on_validation_batch_end.
        self.val_outputs = []

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        self.val_outputs.append(outputs)

    def on_validation_epoch_end(self, trainer, pl_module):
        node_emb = th.cat(self.val_outputs, 0)
        datamodule = trainer.datamodule
        f1_micro, f1_macro = compute_acc(
            node_emb,
            datamodule.g.ndata["labels"],
            datamodule.train_nid,
            datamodule.val_nid,
            datamodule.test_nid,
        )
        pl_module.log("val_f1_micro", f1_micro)

201

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
202
if __name__ == "__main__":
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=int, default=0)
    argparser.add_argument("--dataset", type=str, default="reddit")
    argparser.add_argument("--num-epochs", type=int, default=20)
    argparser.add_argument("--num-hidden", type=int, default=16)
    argparser.add_argument("--num-layers", type=int, default=2)
    argparser.add_argument("--num-negs", type=int, default=1)
    argparser.add_argument(
        "--neg-share",
        default=False,
        action="store_true",
        help="sharing neg nodes for positive nodes",
    )
    argparser.add_argument("--fan-out", type=str, default="10,25")
    argparser.add_argument("--batch-size", type=int, default=10000)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--eval-every", type=int, default=1000)
    argparser.add_argument("--lr", type=float, default=0.003)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Number of sampling processes. Use 0 for no extra process.",
    )
    args = argparser.parse_args()

    # Any negative GPU id means "run on CPU".
    if args.gpu >= 0:
        device = th.device("cuda:%d" % args.gpu)
    else:
        device = th.device("cpu")

    datamodule = DataModule(
        args.dataset,
        True,
        [int(_) for _ in args.fan_out.split(",")],
        device,
        args.batch_size,
        args.num_workers,
    )
    model = SAGELightning(
        datamodule.in_feats,
        args.num_hidden,
        datamodule.n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
        args.lr,
    )

    # Train
    unsupervised_callback = UnsupervisedClassification()
    checkpoint_callback = ModelCheckpoint(monitor="val_f1_micro", save_top_k=1)
    trainer = Trainer(
        # Use the same >= 0 test as the device selection above so every
        # negative --gpu value consistently means CPU (previously only
        # -1 did, and e.g. --gpu -2 passed an invalid id to Trainer
        # while tensors were placed on the CPU).
        gpus=[args.gpu] if args.gpu >= 0 else None,
        max_epochs=args.num_epochs,
        val_check_interval=1000,
        callbacks=[checkpoint_callback, unsupervised_callback],
        num_sanity_val_steps=0,
    )
    trainer.fit(model, datamodule=datamodule)