"vscode:/vscode.git/clone" did not exist on "22b45304bf85a3c5281753d6b3259ccaf96e5085"
node_classification.py 7.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
"""
This flowchart describes the main functional sequence of the provided example.
main

├───> Instantiate DataModule
│     │
│     └───> Load dataset
│     │
│     └───> Create train and valid dataloader[HIGHLIGHT]
│           │
│           └───> ItemSampler (Distribute data to minibatchs)
│           │
13
14
│           └───> sample_neighbor or sample_layer_neighbor
                  (Sample a subgraph for a minibatch)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
│           │
│           └───> fetch_feature (Fetch features for the sampled subgraph)

├───> Instantiate GraphSAGE model
│     │
│     ├───> SAGEConvLayer (input to hidden)
│     │
│     └───> SAGEConvLayer (hidden to hidden)
│     │
│     └───> SAGEConvLayer (hidden to output)
│     │
│     └───> DropoutLayer

└───> Run


      └───> Trainer[HIGHLIGHT]

33
            ├───> SAGE.forward (GraphSAGE model forward pass)
34
35
36
37
38
39
40
41
42
43
44
45

            └───> Validate
"""
import argparse

import dgl.graphbolt as gb
import dgl.nn.pytorch as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy


class SAGE(LightningModule):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)

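    # Each `block` is a bipartite message-flow graph (MFG) produced by the
    # sampler for one layer; representations shrink from the sampled input
    # nodes of the first block down to the seed nodes of the last.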
    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

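    # Log how many nodes and edges each sampled layer (MFG) touches, making
    # the size of the sampled subgraphs visible in the progress bar.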
    def log_node_and_edge_counts(self, blocks):
        node_counts = [block.num_src_nodes() for block in blocks] + [
            blocks[-1].num_dst_nodes()
        ]
        edge_counts = [block.num_edges() for block in blocks]
        for i, c in enumerate(node_counts):
            self.log(
                f"num_nodes/{i}",
                float(c),
                prog_bar=True,
                on_step=True,
                on_epoch=False,
            )
            if i < len(edge_counts):
                self.log(
                    f"num_edges/{i}",
                    float(edge_counts[i]),
                    prog_bar=True,
                    on_step=True,
                    on_epoch=False,
                )

    def training_step(self, batch, batch_idx):
        batch = batch.to_dgl()
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"].to("cuda")
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_node_and_edge_counts(blocks)
        return loss

    def validation_step(self, batch, batch_idx):
        batch = batch.to_dgl()
        blocks = [block.to("cuda") for block in batch.blocks]
        x = batch.node_features["feat"].to("cuda")
        y = batch.labels.to("cuda")
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        self.log_node_and_edge_counts(blocks)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, weight_decay=5e-4
        )
        return optimizer

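
# A minimal smoke-test sketch (the toy graph and shapes are made up, and this
# helper is not called anywhere in the example) showing the `forward`
# contract: one MFG block per layer, features for the source nodes of the
# first block in, logits for the seed nodes of the last block out.
def _sage_smoke_test():
    import dgl

    g = dgl.rand_graph(10, 40)
    # Using every node as a seed makes src == dst, so the same toy block can
    # stand in for all three layers.
    block = dgl.to_block(g, dst_nodes=torch.arange(10))
    model = SAGE(in_feats=8, n_hidden=16, n_classes=3)
    logits = model([block] * 3, torch.randn(10, 8))
    assert logits.shape == (10, 3)
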

class DataModule(LightningDataModule):
    def __init__(self, dataset, fanouts, batch_size, num_workers):
        super().__init__()
        self.fanouts = fanouts
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.feature_store = dataset.feature
        self.graph = dataset.graph
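        # The builtin dataset exposes its splits and metadata through tasks;
        # this example uses the first (node classification) task.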
        self.train_set = dataset.tasks[0].train_set
        self.valid_set = dataset.tasks[0].validation_set
        self.num_classes = dataset.tasks[0].metadata["num_classes"]

    def create_dataloader(self, node_set, is_train):
        datapipe = gb.ItemSampler(
            node_set,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
        )
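        # `sample_layer_neighbor` performs layer-neighbor (LABOR) sampling,
        # which samples fewer unique nodes than independent per-seed neighbor
        # sampling for the same fanouts; validation uses plain neighbor
        # sampling.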
        sampler = (
            datapipe.sample_layer_neighbor
            if is_train
            else datapipe.sample_neighbor
        )
        datapipe = sampler(self.graph, self.fanouts)
        datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
        dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)
        return dataloader

    ########################################################################
    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
    # essential components of the Lightning framework, defining how data is
    # loaded during training and validation. In this example, we use a
    # specialized 'graphbolt dataloader', which is composed of a chain of
    # datapipes, for these purposes.
    ########################################################################
    def train_dataloader(self):
        return self.create_dataloader(self.train_set, is_train=True)

    def val_dataloader(self):
        return self.create_dataloader(self.valid_set, is_train=False)

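
# A minimal sketch (hypothetical helper, not used by the script) of what a
# single minibatch produced by the dataloaders above contains, assuming the
# datapipe chain built in `create_dataloader`:
def _inspect_minibatch(dataloader):
    batch = next(iter(dataloader)).to_dgl()
    # One message-flow graph per sampled layer, outermost layer first.
    for block in batch.blocks:
        print(block.num_src_nodes(), block.num_dst_nodes(), block.num_edges())
    # Input-node features fetched by `fetch_feature`, and seed-node labels.
    print(batch.node_features["feat"].shape, batch.labels.shape)
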

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GNN baselines on ogbn-products data with GraphBolt"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="number of GPUs used for computing (default: 1)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="input batch size for training (default: 1024)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=40,
        help="number of epochs to train (default: 40)",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="number of workers (default: 0)",
    )
    args = parser.parse_args()

    dataset = gb.BuiltinDataset("ogbn-products").load()
    datamodule = DataModule(
        dataset,
        [10, 10, 10],
        args.batch_size,
        args.num_workers,
    )
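    # Width of the "feat" vectors stored for (untyped) nodes, i.e. the model's
    # input dimensionality.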
    in_size = dataset.feature.size("node", None, "feat")[0]
    model = SAGE(in_size, 256, datamodule.num_classes)

    # Train.
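    # Keep the checkpoint with the best validation accuracy, and stop training
    # once val_acc stops improving (EarlyStopping's default patience).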
    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
    ########################################################################
    # (HIGHLIGHT) The `Trainer` is the key class in Lightning: it automates
    # everything after the `LightningDataModule` and `LightningModule` are
    # defined. More details can be found at
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html.
    ########################################################################
    trainer = Trainer(
        accelerator="gpu",
        devices=args.num_gpus,
        max_epochs=args.epochs,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
    trainer.fit(model, datamodule=datamodule)
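
# Typical invocation, using the flags defined above (defaults otherwise):
#   python node_classification.py --num_gpus 1 --batch_size 1024 --epochs 40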