"src/vscode:/vscode.git/clone" did not exist on "8ae8008b0d096d2b093f5b7c660715a93f74f17a"
node_classification.py 7.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
"""
This flowchart describes the main functional sequence of the provided example.
main

├───> Instantiate DataModule
│     │
│     └───> Load dataset
│     │
│     └───> Create train and valid dataloader[HIGHLIGHT]
│           │
│           └───> ItemSampler (Distribute data to minibatchs)
│           │
13
14
│           └───> sample_neighbor or sample_layer_neighbor
                  (Sample a subgraph for a minibatch)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
│           │
│           └───> fetch_feature (Fetch features for the sampled subgraph)

├───> Instantiate GraphSAGE model
│     │
│     ├───> SAGEConvLayer (input to hidden)
│     │
│     └───> SAGEConvLayer (hidden to hidden)
│     │
│     └───> SAGEConvLayer (hidden to output)
│     │
│     └───> DropoutLayer

└───> Run


      └───> Trainer[HIGHLIGHT]

33
            ├───> SAGE.forward (GraphSAGE model forward pass)
34
35
36
37
38
39
40
41
42
43
44
45

            └───> Validate
"""
import argparse

import dgl.graphbolt as gb
import dgl.nn.pytorch as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy


class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)

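    # 'blocks' is the list of bipartite message-flow graphs emitted by the
    # neighbor sampler, one per SAGEConv layer; 'x' holds the features of
    # the source nodes of the first block.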
    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

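    # Log the per-layer node and edge counts of the sampled blocks so the
    # effective fan-out of the sampler can be monitored on the progress bar.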
    def log_node_and_edge_counts(self, blocks):
        node_counts = [block.num_src_nodes() for block in blocks] + [
            blocks[-1].num_dst_nodes()
        ]
        edge_counts = [block.num_edges() for block in blocks]
        for i, c in enumerate(node_counts):
            self.log(
                f"num_nodes/{i}",
                float(c),
                prog_bar=True,
                on_step=True,
                on_epoch=False,
            )
            if i < len(edge_counts):
                self.log(
                    f"num_edges/{i}",
                    float(edge_counts[i]),
                    prog_bar=True,
                    on_step=True,
                    on_epoch=False,
                )

    def training_step(self, batch, batch_idx):
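        # 'batch' is a GraphBolt minibatch carrying the sampled blocks, the
        # fetched input features, and the seed-node labels.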
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"]
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_node_and_edge_counts(blocks)
        return loss

    def validation_step(self, batch, batch_idx):
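        # Validation mirrors the training step, but accuracy is accumulated
        # over the epoch and synchronized across devices when logging.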
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"]
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log_node_and_edge_counts(blocks)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer


class DataModule(LightningDataModule):
    def __init__(self, dataset, fanouts, batch_size, num_workers):
        super().__init__()
        self.fanouts = fanouts
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.feature_store = dataset.feature
        self.graph = dataset.graph
        self.train_set = dataset.tasks[0].train_set
        self.valid_set = dataset.tasks[0].validation_set
        self.num_classes = dataset.tasks[0].metadata["num_classes"]

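    # Build the GraphBolt datapipe chain: sample seed nodes into minibatches,
    # sample a subgraph per minibatch, fetch its features, and convert the
    # result into DGL blocks consumable by the model.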
    def create_dataloader(self, node_set, is_train):
        datapipe = gb.ItemSampler(
            node_set,
            batch_size=self.batch_size,
            shuffle=is_train,
            drop_last=is_train,
        )
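        # Layer-neighbor sampling (a LABOR-style sampler) is used for
        # training to reduce the number of sampled nodes; plain neighbor
        # sampling is used for validation.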
        sampler = (
            datapipe.sample_layer_neighbor
            if is_train
            else datapipe.sample_neighbor
        )
        datapipe = sampler(self.graph, self.fanouts)
        datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
        datapipe = datapipe.to_dgl()
        dataloader = gb.MultiProcessDataLoader(
            datapipe, num_workers=self.num_workers
        )
        return dataloader

    ########################################################################
    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
    # essential components of the Lightning framework, defining how data is
    # loaded during training and validation. In this example, we use a
    # specialized GraphBolt dataloader, which is composed of a chain of
    # datapipes, for these purposes.
    ########################################################################
    def train_dataloader(self):
        return self.create_dataloader(self.train_set, is_train=True)

    def val_dataloader(self):
        return self.create_dataloader(self.valid_set, is_train=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GNN baselines on ogbn-products data with GraphBolt"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="number of GPUs used for computing (default: 1)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="input batch size for training (default: 1024)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=40,
        help="number of epochs to train (default: 40)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    args = parser.parse_args()

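    # Load the built-in ogbn-products dataset (downloaded on first use) and
    # sample three layers with fanouts [15, 10, 5].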
    dataset = gb.BuiltinDataset("ogbn-products").load()
    datamodule = DataModule(
        dataset,
        [15, 10, 5],
        args.batch_size,
        args.num_workers,
    )
    in_size = dataset.feature.size("node", None, "feat")[0]
    model = SAGE(in_size, 256, datamodule.num_classes)

    # Train.
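    # Checkpoint on the best validation accuracy and stop early once it
    # stops improving.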
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
    ########################################################################
    # (HIGHLIGHT) The `Trainer` is the key class in Lightning, which
    # automates everything after the `LightningDataModule` and
    # `LightningModule` are defined. More details can be found at
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html.
    ########################################################################
    trainer = Trainer(
        accelerator="gpu",
        devices=args.num_gpus,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
    trainer.fit(model, datamodule=datamodule)
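
    # A minimal sketch of how this script is typically invoked (assuming a
    # single-GPU machine; the flags are the argparse options defined above):
    #   python node_classification.py --num_gpus 1 --batch_size 1024 \
    #       --epochs 40 --num_workers 0
    # To inspect sampling outside Lightning, the dataloader can also be
    # consumed directly, e.g.:
    #   batch = next(iter(datamodule.train_dataloader()))
    #   print([b.num_edges() for b in batch.blocks])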