"""
This flowchart describes the main functional sequence of the provided example.
main

├───> Instantiate DataModule
│     │
│     └───> Load dataset
│     │
│     └───> Create train and valid dataloader[HIGHLIGHT]
│           │
│           └───> ItemSampler (Distribute data to minibatchs)
│           │
13
14
│           └───> sample_neighbor or sample_layer_neighbor
                  (Sample a subgraph for a minibatch)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
│           │
│           └───> fetch_feature (Fetch features for the sampled subgraph)

├───> Instantiate GraphSAGE model
│     │
│     ├───> SAGEConvLayer (input to hidden)
│     │
│     └───> SAGEConvLayer (hidden to hidden)
│     │
│     └───> SAGEConvLayer (hidden to output)
│     │
│     └───> DropoutLayer

└───> Run


      └───> Trainer[HIGHLIGHT]

33
            ├───> SAGE.forward (GraphSAGE model forward pass)
34
35
36
37
38
39
40
41
42
43
44
45

            └───> Validate
"""
import argparse

import dgl.graphbolt as gb
import dgl.nn.pytorch as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy


class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
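        # Record the constructor arguments so they are stored in checkpoints.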
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, blocks, x):
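        # `blocks` is a list of bipartite message flow graphs (MFGs), one per
        # SAGEConv layer; ReLU and dropout are applied after every layer
        # except the last.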
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def log_node_and_edge_counts(self, blocks):
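        # Log per-layer node and edge counts of the sampled subgraph. There
        # is one more node count than edge count because the destination
        # nodes of the final block are counted as well.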
        node_counts = [block.num_src_nodes() for block in blocks] + [
            blocks[-1].num_dst_nodes()
        ]
        edge_counts = [block.num_edges() for block in blocks]
        for i, c in enumerate(node_counts):
            self.log(
                f"num_nodes/{i}",
                float(c),
                prog_bar=True,
                on_step=True,
                on_epoch=False,
            )
            if i < len(edge_counts):
                self.log(
                    f"num_edges/{i}",
                    float(edge_counts[i]),
                    prog_bar=True,
                    on_step=True,
                    on_epoch=False,
                )

    def training_step(self, batch, batch_idx):
        # TODO: Move this to the data pipeline as a stage.
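        # Convert the GraphBolt minibatch into DGL blocks and move them,
        # together with the labels, onto the GPU.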
        blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
        x = blocks[0].srcdata["feat"]
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_node_and_edge_counts(blocks)
        return loss

    def validation_step(self, batch, batch_idx):
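        # Mirrors the training step, but accuracy is aggregated over the
        # whole epoch and synchronized across devices.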
        blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
        x = blocks[0].srcdata["feat"]
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log_node_and_edge_counts(blocks)

    def configure_optimizers(self):
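        # Adam with L2 regularization applied through `weight_decay`.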
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer


class DataModule(LightningDataModule):
    def __init__(self, fanouts, batch_size, num_workers):
        super().__init__()
        self.fanouts = fanouts
        self.batch_size = batch_size
        self.num_workers = num_workers
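        # Load the builtin ogbn-products dataset; GraphBolt downloads and
        # preprocesses it on first use.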
        dataset = gb.BuiltinDataset("ogbn-products").load()
        self.feature_store = dataset.feature
        self.graph = dataset.graph
        self.train_set = dataset.tasks[0].train_set
        self.valid_set = dataset.tasks[0].validation_set
        self.num_classes = dataset.tasks[0].metadata["num_classes"]

    def create_dataloader(self, node_set, is_train):
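        # The dataloader is a chain of datapipes: minibatch sampling ->
        # neighbor sampling -> feature fetching.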
        datapipe = gb.ItemSampler(
            node_set,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
        )
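        # `sample_layer_neighbor` performs layer-dependent (LABOR) sampling,
        # which typically samples fewer unique nodes per minibatch than
        # independent neighbor sampling; validation uses plain
        # `sample_neighbor`.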
        sampler = (
            datapipe.sample_layer_neighbor
            if is_train
            else datapipe.sample_neighbor
        )
        datapipe = sampler(self.graph, self.fanouts)
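        # Fetch the "feat" node feature for the input nodes of the sampled
        # subgraph from the feature store.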
        datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
        dataloader = gb.MultiProcessDataLoader(
            datapipe, num_workers=self.num_workers
        )
        return dataloader

    ########################################################################
    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
    # essential components of the Lightning framework, defining how data is
    # loaded during training and validation. In this example, we use a
    # specialized GraphBolt dataloader, which is built by chaining a series
    # of datapipes, for these purposes.
    ########################################################################
    def train_dataloader(self):
        return self.create_dataloader(self.train_set, is_train=True)

    def val_dataloader(self):
        return self.create_dataloader(self.valid_set, is_train=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GNN baselines on ogbn-products data with GraphBolt"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="number of GPUs used for computing (default: 1)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="input batch size for training (default: 1024)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=40,
        help="number of epochs to train (default: 40)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    args = parser.parse_args()
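    # Example invocation (single GPU):
    #   python node_classification.py --num_gpus 1 --batch_size 1024 --epochs 40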

    # One fanout per SAGEConv layer.
    datamodule = DataModule([15, 10, 5], args.batch_size, args.num_workers)
    # ogbn-products node features are 100-dimensional float32 vectors, so the
    # model is kept in single precision to match them.
    model = SAGE(100, 256, datamodule.num_classes)

    # Train.
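    # Checkpoint the best model and stop training early based on validation
    # accuracy.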
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
    ########################################################################
    # (HIGHLIGHT) The `Trainer` is the key class in Lightning: it automates
    # everything once the `LightningModule` and `LightningDataModule` are
    # defined. More details can be found at
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html.
    ########################################################################
    trainer = Trainer(
        accelerator="gpu",
        devices=args.num_gpus,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
    trainer.fit(model, datamodule=datamodule)