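"""GraphSAGE node classification on ogbn-products with DGL and PyTorch Lightning.

Trains a three-layer GraphSAGE model with neighbor sampling on multiple GPUs via
DDP, then evaluates the best checkpoint with layer-wise offline inference.
"""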
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn.pytorch as dglnn
import tqdm
import glob
import os
from ogb.nodeproppred import DglNodePropPredDataset
from torchmetrics import Accuracy
import torchmetrics.functional as MF
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
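        # Note: torchmetrics >= 0.11 requires the task to be explicit, e.g.
        # Accuracy(task='multiclass', num_classes=n_classes); the bare constructors
        # below target the older torchmetrics API this example was written against.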
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, blocks, x):
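        # `blocks` is a list of message flow graphs (MFGs), one per layer, produced
        # by the neighbor sampler; `x` holds the input features of the first block's
        # source nodes.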
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # Offline, layer-wise inference over the full graph. Unlike the official
        # example, intermediate representations are stored in g.ndata['h'], so they
        # benefit from prefetching as well.
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
                persistent_workers=(num_workers > 0))
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
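            # Allocate this layer's full-graph output on buffer_device (e.g. CPU)
            # so the intermediate representations need not fit in GPU memory.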
            y = torch.zeros(
                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata['h'] = y
        return y

    def training_step(self, batch, batch_idx):
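        # Each batch is an (input_nodes, output_nodes, blocks) triple from the DGL
        # DataLoader: features come from the first block's source nodes, labels
        # from the last block's destination nodes.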
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer

class DataModule(LightningDataModule):
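    # Wraps the DGL neighbor-sampling dataloaders for Lightning. `fanouts` gives
    # the number of neighbors to sample at each layer; prefetch_node_feats and
    # prefetch_labels let DGL overlap feature slicing and transfer with sampling.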
    def __init__(self, graph, train_idx, val_idx, fanouts, batch_size, n_classes):
        super().__init__()

        sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=['feat'], prefetch_labels=['label'])

        self.g = graph
        self.train_idx, self.val_idx = train_idx, val_idx
        self.sampler = sampler
        self.batch_size = batch_size
        self.in_feats = graph.ndata['feat'].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g, self.train_idx.to('cuda'), self.sampler,
            device='cuda', batch_size=self.batch_size, shuffle=True, drop_last=False,
            # For CPU sampling, set num_workers to nonzero and use_uva=False
            # Set use_ddp to False for single GPU.
            num_workers=0, use_uva=True, use_ddp=True)

    def val_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g, self.val_idx.to('cuda'), self.sampler,
            device='cuda', batch_size=self.batch_size, shuffle=False, drop_last=False,
            num_workers=0, use_uva=True)

if __name__ == '__main__':
    dataset = DglNodePropPredDataset('ogbn-products')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels.squeeze()
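    # Materialize all sparse formats up front so the DDP processes spawned later
    # reuse them instead of each constructing its own copy.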
    graph.create_formats_()
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
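    # One fanout per SAGE layer, ordered from the input layer outward.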
    datamodule = DataModule(graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes)
    model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)

    # Train
    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
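    # Note: Trainer(gpus=...) and strategy='ddp_spawn' follow the PyTorch Lightning
    # 1.x API; Lightning 2.x replaces `gpus` with accelerator='gpu' and devices=[...].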
    # Use this for single GPU
    #trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback])
    trainer = Trainer(gpus=[0, 1, 2, 3], max_epochs=10, callbacks=[checkpoint_callback], strategy='ddp_spawn')
    trainer.fit(model, datamodule=datamodule)

    # Test
    dirs = glob.glob('./lightning_logs/*')
    version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs])
    logdir = './lightning_logs/version_%d' % version
    print('Evaluating model in', logdir)
    ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0]

    model = SAGE.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to('cuda')
    model.eval()  # disable dropout so offline inference is deterministic
    with torch.no_grad():
        pred = model.inference(graph, 'cuda', 4096, 12, graph.device)
        pred = pred[test_idx]
        label = graph.ndata['label'][test_idx]
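        # With torchmetrics >= 0.11 this call needs task='multiclass' and
        # num_classes=dataset.num_classes, matching the Accuracy note above.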
        acc = MF.accuracy(pred, label)
    print('Test accuracy:', acc)