"""
This flowchart describes the main functional sequence of the provided example.
main

├───> Instantiate DataModule
│     │
│     └───> Load dataset
│     │
│     └───> Create train and valid dataloader[HIGHLIGHT]
│           │
│           └───> ItemSampler (Distribute data to minibatchs)
│           │
13
14
│           └───> sample_neighbor or sample_layer_neighbor
                  (Sample a subgraph for a minibatch)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
│           │
│           └───> fetch_feature (Fetch features for the sampled subgraph)

├───> Instantiate GraphSAGE model
│     │
│     ├───> SAGEConvLayer (input to hidden)
│     │
│     └───> SAGEConvLayer (hidden to hidden)
│     │
│     └───> SAGEConvLayer (hidden to output)
│     │
│     └───> DropoutLayer

└───> Run


      └───> Trainer[HIGHLIGHT]

33
            ├───> SAGE.forward (GraphSAGE model forward pass)
34
35
36
37
38
39
40
41
42
43
44
45

            └───> Validate
"""
import argparse

import dgl.graphbolt as gb
import dgl.nn.pytorch as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy


# A three-layer GraphSAGE model with mean aggregation, wrapped in a
# LightningModule so that Lightning's Trainer can drive training and
# validation.
class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)

    # `blocks` is a list of bipartite message-flow graphs produced by the
    # neighbor sampler, one per GraphSAGE layer.
    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

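    # Log the number of nodes and edges in each sampled block so the effect
    # of the chosen fanouts is visible on the progress bar.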
    def log_node_and_edge_counts(self, blocks):
        node_counts = [block.num_src_nodes() for block in blocks] + [
            blocks[-1].num_dst_nodes()
        ]
        edge_counts = [block.num_edges() for block in blocks]
        for i, c in enumerate(node_counts):
            self.log(
                f"num_nodes/{i}",
                float(c),
                prog_bar=True,
                on_step=True,
                on_epoch=False,
            )
            if i < len(edge_counts):
                self.log(
                    f"num_edges/{i}",
                    float(edge_counts[i]),
                    prog_bar=True,
                    on_step=True,
                    on_epoch=False,
                )

    def training_step(self, batch, batch_idx):
        # Move the sampled blocks and the minibatch tensors onto the GPU so
        # they match the device of the model parameters.
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"].to("cuda")
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_node_and_edge_counts(blocks)
        return loss

    def validation_step(self, batch, batch_idx):
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"].to("cuda")
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log_node_and_edge_counts(blocks)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer


# Wraps the GraphBolt dataset, graph, and feature store behind Lightning's
# LightningDataModule interface.
class DataModule(LightningDataModule):
    def __init__(self, fanouts, batch_size, num_workers):
        super().__init__()
        self.fanouts = fanouts
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Load the built-in ogbn-products dataset (downloaded on first use).
        dataset = gb.BuiltinDataset("ogbn-products").load()
        self.feature_store = dataset.feature
        self.graph = dataset.graph
        self.train_set = dataset.tasks[0].train_set
        self.valid_set = dataset.tasks[0].validation_set
        self.num_classes = dataset.tasks[0].metadata["num_classes"]

    def create_dataloader(self, node_set, is_train):
        # Distribute the seed nodes into minibatches. Only the training
        # loader shuffles and drops the last incomplete batch, so that
        # validation covers the full validation set.
        datapipe = gb.ItemSampler(
            node_set,
            batch_size=self.batch_size,
            shuffle=is_train,
            drop_last=is_train,
        )
        # Use layer-neighbor sampling during training and classic neighbor
        # sampling during validation.
        sampler = (
            datapipe.sample_layer_neighbor
            if is_train
            else datapipe.sample_neighbor
        )
        datapipe = sampler(self.graph, self.fanouts)
        # Fetch the input features for the sampled subgraph.
        datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
        # Convert the minibatch into DGL blocks consumable by SAGEConv.
        datapipe = datapipe.to_dgl()
        dataloader = gb.MultiProcessDataLoader(
            datapipe, num_workers=self.num_workers
        )
        return dataloader

    ########################################################################
    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
    # essential components of the Lightning framework, defining how data is
    # loaded during training and validation. In this example, we use a
    # specialized GraphBolt dataloader, which is built by chaining a series
    # of datapipes, for these purposes.
    ########################################################################
    def train_dataloader(self):
        return self.create_dataloader(self.train_set, is_train=True)

    def val_dataloader(self):
        return self.create_dataloader(self.valid_set, is_train=False)
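
    # For reference, a minimal sketch of consuming one of these dataloaders
    # outside of Lightning (names as defined above; not executed by this
    # script):
    #
    #   datamodule = DataModule([15, 10, 5], batch_size=1024, num_workers=0)
    #   for batch in datamodule.train_dataloader():
    #       blocks = batch.blocks                # one block per GNN layer
    #       x = batch.node_features["feat"]      # input node features
    #       y = batch.labels                     # labels of the seed nodes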


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GNN baselines on ogbn-products data with GraphBolt"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="number of GPUs used for computing (default: 1)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="input batch size for training (default: 1024)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=40,
        help="number of epochs to train (default: 40)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    args = parser.parse_args()

    datamodule = DataModule([15, 10, 5], args.batch_size, args.num_workers)
    # ogbn-products has 100-dimensional float32 node features, so the model
    # is created with in_feats=100 and kept in the default float32 precision.
    model = SAGE(100, 256, datamodule.num_classes)

    # Train.
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
    ########################################################################
    # (HIGHLIGHT) The `Trainer` is the key class in Lightning, which
    # automates everything after the `LightningModule` and the
    # `LightningDataModule` have been defined. More details can be found at
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html.
    ########################################################################
    trainer = Trainer(
        accelerator="gpu",
        devices=args.num_gpus,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
    trainer.fit(model, datamodule=datamodule)
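
    # The best checkpoint (by val_acc) could then be evaluated with a call
    # along these lines (a sketch; not part of the original flow):
    #   trainer.validate(
    #       model,
    #       datamodule=datamodule,
    #       ckpt_path=checkpoint_callback.best_model_path,
    #   )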