"""
This script trains and tests a GraphSAGE model for node classification on
large graphs using efficient neighbor sampling.

Paper: [Inductive Representation Learning on Large Graphs]
(https://arxiv.org/abs/1706.02216)

Before reading this example, please familiarize yourself with GraphSAGE node
classification by reading the example in
`examples/core/graphsage/node_classification.py`.

If you want to train GraphSAGE on a large graph in a distributed fashion, read
the example in `examples/distributed/graphsage/`.

This flowchart describes the main functional sequence of the provided example.
main
│
├───> Load and preprocess dataset
│
├───> Instantiate SAGE model
│
├───> train
│     │
│     ├───> NeighborSampler (HIGHLIGHT)
│     │
│     └───> Training loop
│           │
│           └───> SAGE.forward
│
└───> layerwise_infer
      │
      └───> SAGE.inference
            │
            └───> MultiLayerFullNeighborSampler (HIGHLIGHT)
"""

import argparse
import time

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from ogb.nodeproppred import DglNodePropPredDataset


class SAGE(nn.Module):
    """Three-layer GraphSAGE model using mean aggregation.

    Parameters
    ----------
    in_size : int
        Size of the input node features.
    hidden_size : int
        Size of the output of the first two layers.
    out_size : int
        Size of the output of the last layer.
    """

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        # Kept so that `inference` can size its per-layer output buffers.
        self.hidden_size = hidden_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Apply the layers to a mini-batch, one block per layer.

        Parameters
        ----------
        blocks : sequence
            One block per layer; ``blocks[i]`` is passed to ``self.layers[i]``.
        x : torch.Tensor
            Input features fed to the first layer.

        Returns
        -------
        torch.Tensor
            Output of the last layer. ReLU and dropout are applied after
            every layer except the last one.
        """
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x

    def inference(self, g, device, batch_size, fused_sampling: bool = True):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is evaluated for every node of `g` before the next layer
        starts; the output of one layer is the input of the next.

        Parameters
        ----------
        g : graph
            Graph whose ``ndata["feat"]`` holds the input features.
        device : torch.device
            Device on which each layer is computed.
        batch_size : int
            Number of output nodes per mini-batch.
        fused_sampling : bool
            Forwarded to the sampler as its ``fused`` argument.

        Returns
        -------
        torch.Tensor
            Output of the last layer for all nodes, stored on the CPU.
        """
        feat = g.ndata["feat"]
        #####################################################################
        # (HIGHLIGHT) Creating a MultiLayerFullNeighborSampler instance.
        # It is used here for layer-wise inference: every pass over the
        # dataloader below computes the output of one layer for all nodes.
        #
        # The first argument '1' indicates the number of layers for
        # the neighbor sampling. In this case, it's set to 1, meaning
        # only the direct neighbors of each node will be included in the
        # sampling.
        #
        # The 'prefetch_node_feats' parameter specifies the node features
        # that need to be pre-fetched during sampling. In this case, the
        # feature named 'feat' will be pre-fetched.
        #
        # `prefetch` in DGL initiates data fetching operations in parallel
        # with model computations. This ensures data is ready when the
        # computation needs it, thereby eliminating waiting times between
        # fetching and computing steps and reducing the I/O overhead.
        #
        # The difference between whether to use prefetch or not is shown:
        #
        # Without Prefetch:
        # Fetch1 ──> Compute1 ──> Fetch2 ──> Compute2 ──> Fetch3 ──> Compute3
        #
        # With Prefetch:
        # Fetch1 ──> Fetch2 ──> Fetch3
        #    │          │          │
        #    └─Compute1 └─Compute2 └─Compute3
        #
        # NOTE(review): the loop below reads its inputs from
        # `feat[input_nodes]`, not from the prefetched block data — confirm
        # whether prefetching 'feat' is still needed here.
        #####################################################################
        sampler = MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["feat"], fused=fused_sampling
        )

        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        buffer_device = torch.device("cpu")
        # Enable pin_memory for faster CPU to GPU data transfer if the
        # model is running on a GPU.
        pin_memory = buffer_device != device

        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx == len(self.layers) - 1
            # Output buffer of this layer for all nodes, kept on the CPU.
            y = torch.empty(
                g.num_nodes(),
                self.out_size if is_last_layer else self.hidden_size,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                hidden_x = layer(blocks[0], x)  # len(blocks) = 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
                # By design, our output nodes are contiguous (the seeds are
                # `arange(num_nodes)` iterated without shuffling).
                y[output_nodes[0] : output_nodes[-1] + 1] = hidden_x.to(
                    buffer_device
                )
            # The output of this layer is the input of the next one.
            feat = y
        return y


@torch.no_grad()
def evaluate(model, graph, dataloader, num_classes):
    """Return the multiclass accuracy of `model` over all batches of `dataloader`.

    The model is switched to eval mode. For every mini-batch the input
    features are read from the source nodes of the first block and the
    labels from the destination nodes of the last block. `graph` is accepted
    for interface compatibility but is not used.
    """
    model.eval()
    labels, predictions = [], []
    for _input_nodes, _output_nodes, blocks in dataloader:
        features = blocks[0].srcdata["feat"]
        labels.append(blocks[-1].dstdata["label"])
        predictions.append(model(blocks, features))
    all_predictions = torch.cat(predictions)
    all_labels = torch.cat(labels)
    return MF.accuracy(
        all_predictions,
        all_labels,
        task="multiclass",
        num_classes=num_classes,
    )


@torch.no_grad()
def layerwise_infer(
    device, graph, nid, model, num_classes, batch_size, fused_sampling
):
    """Return the multiclass accuracy on nodes `nid` using layer-wise inference.

    The model is switched to eval mode, `model.inference` computes the
    output for every node of `graph`, and the rows selected by `nid` are
    compared against ``graph.ndata["label"]``.
    """
    model.eval()
    pred = model.inference(
        graph, device, batch_size, fused_sampling
    )  # pred in buffer_device.
    pred = pred[nid]
    # Move the labels to wherever the predictions live before comparing.
    label = graph.ndata["label"][nid].to(pred.device)
    return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes)


def train(device, g, dataset, model, num_classes, use_uva, fused_sampling):
    """Train `model` for 10 epochs with neighbor sampling.

    After every epoch the average training loss, the validation accuracy and
    the training time of the epoch are printed.

    Parameters
    ----------
    device : torch.device
        Device the mini-batches are placed on.
    g : graph
        Graph providing the ``"feat"`` and ``"label"`` node data.
    dataset : object
        Provides the ``train_idx`` and ``val_idx`` node ID tensors.
    model : nn.Module
        Model called as ``model(blocks, x)``.
    num_classes : int
        Number of classes, used for the accuracy metric.
    use_uva : bool
        Forwarded to the dataloaders; when set, the node IDs are moved to
        `device` instead of the device of `g`.
    fused_sampling : bool
        Forwarded to the sampler as its ``fused`` argument.
    """
    # Create sampler & dataloader.
    train_idx = dataset.train_idx.to(g.device if not use_uva else device)
    val_idx = dataset.val_idx.to(g.device if not use_uva else device)
    #####################################################################
    # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
    # training of Graph Neural Networks (GNNs) on large-scale graphs.
    #
    # The argument [10, 10, 10] sets the number of neighbors (fanout)
    # to be sampled at each layer. Here, we have three layers, and
    # 10 neighbors will be randomly selected for each node at each
    # layer.
    #
    # The 'prefetch_node_feats' and 'prefetch_labels' parameters
    # specify the node features and labels that need to be pre-fetched
    # during sampling. More details about `prefetch` can be found in the
    # `SAGE.inference` function.
    #####################################################################
    sampler = NeighborSampler(
        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
        fused=fused_sampling,
    )

    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        # If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,
        # otherwise it will cause error.
        num_workers=0,
        use_uva=use_uva,
    )

    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        # No need to shuffle for validation.
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        t0 = time.time()
        model.train()
        total_loss = 0
        # Counted explicitly so that the average below is well defined even
        # when the dataloader yields no batch.
        num_batches = 0
        # A block is a graph consisting of two sets of nodes: the
        # source nodes and destination nodes. The source and destination
        # nodes can have multiple node types. All the edges connect from
        # source nodes to destination nodes.
        # For more details: https://discuss.dgl.ai/t/what-is-the-block/2932.
        for input_nodes, output_nodes, blocks in train_dataloader:
            # The input features from the source nodes in the first layer's
            # computation graph.
            x = blocks[0].srcdata["feat"]

            # The ground truth labels from the destination nodes
            # in the last layer's computation graph.
            y = blocks[-1].dstdata["label"]

            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        # Only the training loop is timed; validation is excluded.
        t1 = time.time()
        acc = evaluate(model, g, val_dataloader, num_classes)
        print(
            f"Epoch {epoch:05d} | Loss {total_loss / max(num_batches, 1):.4f} | "
            f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
        )


if __name__ == "__main__":
    # Script entry point: parse the mode, load ogbn-products, train a SAGE
    # model and report its test accuracy.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "gpu", "compare-to-graphbolt"],
        help="Training mode. 'cpu' for CPU training, 'mixed' for "
        "CPU-GPU mixed training, 'gpu' for pure-GPU training. "
        "'compare-to-graphbolt' disables fused neighbor sampling.",
    )
    args = parser.parse_args()
    # Fall back to CPU training when no CUDA device is available.
    if not torch.cuda.is_available():
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # Load and preprocess dataset.
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # The graph is moved to the GPU only in pure-GPU mode.
    g = g.to("cuda" if args.mode == "gpu" else "cpu")
    num_classes = dataset.num_classes
    # Whether use Unified Virtual Addressing (UVA) for CUDA computation.
    use_uva = args.mode == "mixed"
    device = torch.device("cpu" if args.mode == "cpu" else "cuda")
    # Fused sampling is used in every mode except 'compare-to-graphbolt'.
    fused_sampling = args.mode != "compare-to-graphbolt"

    # Create GraphSAGE model.
    in_size = g.ndata["feat"].shape[1]
    out_size = num_classes
    model = SAGE(in_size, 256, out_size).to(device)

    # Model training.
    print("Training...")
    train(device, g, dataset, model, num_classes, use_uva, fused_sampling)

    # Test the model.
    print("Testing...")
    acc = layerwise_infer(
        device,
        g,
        dataset.test_idx,
        model,
        num_classes,
        batch_size=4096,
        fused_sampling=fused_sampling,
    )
    print(f"Test accuracy {acc.item():.4f}")