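"""Multi-GPU node classification on ogbn-products with a three-layer
GraphSAGE model, using DGL for neighbor sampling and graph convolutions
and PyTorch Lightning for training, checkpointing, and DDP."""
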
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy

import dgl
import dgl.nn.pytorch as dglnn


class SAGE(LightningModule):
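    """Three-layer GraphSAGE node classifier as a LightningModule."""
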
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
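        # Three SAGEConv layers with mean aggregation:
        # in_feats -> n_hidden -> n_hidden -> n_classes.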
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
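        # Layer-wise inference: compute the representations of all nodes one
        # layer at a time with a depth-1 full-neighbor sampler, reusing each
        # layer's output ("h") as the next layer's input feature.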
        g.ndata["h"] = g.ndata["feat"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["h"]
        )
        dataloader = dgl.dataloading.DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
        )
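        # The output buffer can live on a different device (e.g. CPU) when
        # the embeddings of all nodes do not fit in GPU memory.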
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
            y = torch.zeros(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device,
            )
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    # Dropout is skipped at inference time.
                    h = F.relu(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata["h"] = y
        return y

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
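        # Node features come from the source nodes of the first block and
        # labels from the seed (destination) nodes of the last block; both
        # were prefetched by the NeighborSampler.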
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer

class DataModule(LightningDataModule):
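    """LightningDataModule wrapping DGL neighbor-sampling dataloaders."""
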
    def __init__(
        self, graph, train_idx, val_idx, fanouts, batch_size, n_classes
    ):
        super().__init__()

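        # One fanout per GraphSAGE layer; prefetch_node_feats and
        # prefetch_labels make the sampler slice features and labels as part
        # of sampling instead of in the training step.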
        sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
        )

        self.g = graph
        self.train_idx, self.val_idx = train_idx, val_idx
        self.sampler = sampler
        self.batch_size = batch_size
        self.in_feats = graph.ndata["feat"].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g,
            self.train_idx.to("cuda"),
            self.sampler,
            device="cuda",
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            # use_uva=True samples on the GPU directly from the pinned CPU
            # graph (unified virtual addressing) and requires num_workers=0.
            # For CPU sampling, set num_workers to nonzero and use_uva=False.
            # Set use_ddp to False for single GPU.
            num_workers=0,
            use_uva=True,
            use_ddp=True,
        )

    def val_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g,
            self.val_idx.to("cuda"),
            self.sampler,
            device="cuda",
            batch_size=self.batch_size,
            shuffle=False,  # no need to shuffle for validation
            drop_last=False,
            num_workers=0,
            use_uva=True,
        )


if __name__ == "__main__":
    dataset = DglNodePropPredDataset("ogbn-products")
    graph, labels = dataset[0]
    graph.ndata["label"] = labels.squeeze()
    # Materialize all sparse formats up front so each DDP worker does not
    # rebuild them lazily on first use.
    graph.create_formats_()
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        split_idx["train"],
        split_idx["valid"],
        split_idx["test"],
    )
    datamodule = DataModule(
        graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes
    )
    model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)

    # Train
    # mode="max" keeps the checkpoint with the highest validation accuracy;
    # the default mode="min" would keep the lowest instead.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", save_top_k=1, mode="max"
    )
    # Use this for single GPU
    # trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback])
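    # Four-GPU DDP training; use_ddp=True in the train dataloader splits the
    # training seeds across ranks every epoch.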
    trainer = Trainer(
        gpus=[0, 1, 2, 3],
        max_epochs=10,
        callbacks=[checkpoint_callback],
        strategy="ddp_spawn",
    )
    trainer.fit(model, datamodule=datamodule)

    # Test
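    # Locate the checkpoint written by the most recent Lightning run.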
    dirs = glob.glob("./lightning_logs/*")
    version = max([int(os.path.split(x)[-1].split("_")[-1]) for x in dirs])
    logdir = "./lightning_logs/version_%d" % version
    print("Evaluating model in", logdir)
    ckpt = glob.glob(os.path.join(logdir, "checkpoints", "*"))[0]

    model = SAGE.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, "hparams.yaml")
    ).to("cuda")
    model.eval()  # make sure dropout stays disabled during inference
    # Offline, layer-wise full-graph inference on GPU, evaluated on the
    # held-out test split.
    with torch.no_grad():
        pred = model.inference(graph, "cuda", 4096, 12, graph.device)
        pred = pred[test_idx]
        label = graph.ndata["label"][test_idx]
        acc = MF.accuracy(pred, label)
    print("Test accuracy:", acc)