node_classification.py 9.67 KB
Newer Older
1
2
"""
This script demonstrates node classification with GraphSAGE on large graphs, 
3
4
5
6
7
8
9
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently
manages data loading for large datasets, crucial for mini-batch processing.
Post data loading, PyG's user-friendly framework takes over for training,
showcasing seamless integration with GraphBolt. This combination offers an
efficient alternative to traditional Deep Graph Library (DGL) methods,
highlighting adaptability and scalability in handling large-scale graph data
for diverse real-world applications.
10
11

Key Features:
12
13
- Implements the GraphSAGE model, a scalable GNN, for node classification on
  large graphs.
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- The script is well-documented, providing clear explanations at each step.

This flowchart describes the main functional sequence of the provided example.
main: 

main

├───> Load and preprocess dataset (GraphBolt)
│     │
│     └───> Utilize GraphBolt's BuiltinDataset for dataset handling

├───> Instantiate the SAGE model (PyTorch Geometric)
│     │
│     └───> Define the GraphSAGE model architecture

├───> Train the model
│     │
│     ├───> Mini-Batch Processing with GraphBolt
│     │     │
│     │     └───> Efficient handling of mini-batches using GraphBolt's utilities
│     │
│     └───> Training Loop
│           │
│           ├───> Forward and backward passes
│           │
41
42
│           ├───> Convert GraphBolt MiniBatch to PyG Data
│           │
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
│           └───> Parameters optimization

└───> Evaluate the model

      └───> Performance assessment on validation and test datasets

            └───> Accuracy and other relevant metrics calculation


"""

import argparse

import dgl.graphbolt as gb
import torch
import torch.nn.functional as F
import torchmetrics.functional as MF
from torch_geometric.nn import SAGEConv
61
from tqdm import tqdm
62
63
64
65
66
67
68
69
70
71
72


class GraphSAGE(torch.nn.Module):
    #####################################################################
    # (HIGHLIGHT) Define the GraphSAGE model architecture.
    #
    # - This class inherits from `torch.nn.Module`.
    # - Three convolutional layers are created using the SAGEConv class
    #   from PyG.
    # - 'in_size', 'hidden_size', 'out_size' are the sizes of
    #   the input, hidden, and output features, respectively.
    # - The forward method defines the computation performed at every call.
    # - It's adopted from the official PyG example which can be found at
    # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
    #####################################################################
    def __init__(self, in_size, hidden_size, out_size):
        """Build the layer stack: in -> hidden -> hidden -> out."""
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                SAGEConv(in_size, hidden_size),
                SAGEConv(hidden_size, hidden_size),
                SAGEConv(hidden_size, out_size),
            ]
        )

    def forward(self, x, edge_index):
        """Apply every layer in turn to ``x`` using the same ``edge_index``.

        ReLU and dropout (p=0.5, active only in training mode) follow every
        layer except the last, whose raw output is returned.
        """
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if i != last:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    def inference(self, args, dataloader, x_all, device):
        """Conduct layer-wise inference to get all the node embeddings.

        For each layer, the whole ``dataloader`` is consumed once; the
        per-batch outputs are moved to CPU and concatenated to become the
        input ``x_all`` of the next layer. No dropout is applied here.

        Args:
            args: Parsed arguments; only ``args.batch_size`` is read.
            dataloader: Iterable of mini-batches exposing ``to_pyg_data()``
                and ``node_ids()``.
            x_all: Tensor indexed by node ID that holds the input of the
                first layer.
            device: Device on which each layer is evaluated.

        Returns:
            The concatenated output of the last layer, on CPU.
        """
        num_layers = len(self.layers)
        # `enumerate` has no length, so give tqdm the total explicitly.
        for i, layer in tqdm(
            enumerate(self.layers), "inference", total=num_layers
        ):
            xs = []
            for minibatch in dataloader:
                # Call `to_pyg_data` to convert GB Minibatch to PyG Data.
                pyg_data = minibatch.to_pyg_data()
                # `x_all` is indexed on CPU, then only the rows needed by
                # this mini-batch are moved to `device`.
                n_ids = minibatch.node_ids().to("cpu")
                x = x_all[n_ids].to(device)
                x = layer(x, pyg_data.edge_index)
                # NOTE(review): keeps at most `4 * args.batch_size` leading
                # rows, which hard-codes the batch size the caller is
                # expected to build `dataloader` with. Presumably the seed
                # nodes come first in the converted mini-batch -- confirm
                # against GraphBolt before changing either side.
                x = x[: 4 * args.batch_size]
                if i != num_layers - 1:
                    x = x.relu()
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)
        return x_all
108
109


110
111
112
def create_dataloader(
    dataset_set, graph, feature, batch_size, fanout, device, job
):
    """Assemble the GraphBolt data pipeline for one job and wrap it.

    Args:
        dataset_set: Item set the mini-batches are drawn from.
        graph: Graph handed to the neighbor sampler.
        feature: Feature store from which the "feat" node feature is fetched.
        batch_size: Number of items per mini-batch.
        fanout: Per-layer fanout for sampling; ignored when ``job`` is
            "infer", where ``[-1]`` is used instead.
        device: Device the sampled mini-batches are copied to.
        job: Shuffling and dropping the last incomplete batch are enabled
            only for "train"; "infer" overrides the fanout.

    Returns:
        A ``gb.DataLoader`` over the assembled pipeline (``num_workers=0``).
    """
    is_training = job == "train"
    # Inference always samples with a single `[-1]` fanout entry
    # (presumably "all neighbors, one hop" -- confirm in GraphBolt docs).
    neighbor_fanout = [-1] if job == "infer" else fanout

    pipeline = (
        # Draw mini-batches of items from the dataset.
        gb.ItemSampler(
            dataset_set,
            batch_size=batch_size,
            shuffle=is_training,
            drop_last=is_training,
        )
        # Sample neighbors for each node in the mini-batch.
        .sample_neighbor(graph, neighbor_fanout)
        # Move the sampled data to the target device before fetching.
        .copy_to(device=device, extra_attrs=["input_nodes"])
        # Attach node features for the sampled subgraph.
        .fetch_feature(feature, node_feature_keys=["feat"])
    )
    return gb.DataLoader(pipeline, num_workers=0)


134
def train(model, dataloader, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module called as ``model(x, edge_index)``.
        dataloader: Iterable of mini-batches exposing ``to_pyg_data()``.
        optimizer: Optimizer stepping the parameters of ``model``.

    Returns:
        ``(avg_loss, avg_accuracy)``: the mean of the per-batch losses and
        the fraction of correctly classified labelled samples.

    Raises:
        ValueError: If ``dataloader`` yields no mini-batches, since the
            averages would be undefined.
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0  # Sum of the per-batch losses
    total_correct = 0  # Number of correct predictions so far
    total_samples = 0  # Number of labelled samples seen so far
    num_batches = 0  # Number of mini-batches processed

    # Iterate the dataloader directly: the batch index was never used, and
    # wrapping in `enumerate` hides any length from tqdm.
    for minibatch in tqdm(dataloader, "training"):
        #####################################################################
        # (HIGHLIGHT) Convert GraphBolt MiniBatch to PyG Data class.
        #
        # Call `MiniBatch.to_pyg_data()` and it will return a PyG Data class
        # with necessary data and information.
        #####################################################################
        pyg_data = minibatch.to_pyg_data()
        y = pyg_data.y

        optimizer.zero_grad()
        # Only the first `len(y)` output rows are compared with the labels.
        out = model(pyg_data.x, pyg_data.edge_index)[: y.shape[0]]
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(y).sum())
        total_samples += y.shape[0]
        num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "The training dataloader yielded no mini-batches; cannot "
            "compute the average loss and accuracy."
        )

    avg_loss = total_loss / num_batches
    avg_accuracy = total_correct / total_samples
    return avg_loss, avg_accuracy


@torch.no_grad()
def evaluate(model, dataloader, num_classes):
    """Compute multiclass accuracy of ``model`` over ``dataloader``.

    Runs in eval mode with gradients disabled. Predictions and labels of
    all mini-batches are collected and scored in a single
    ``MF.accuracy`` call.

    Args:
        model: Module called as ``model(x, edge_index)``.
        dataloader: Iterable of mini-batches exposing ``to_pyg_data()``.
        num_classes: Number of classes passed to the metric.

    Returns:
        The value returned by ``MF.accuracy`` for the collected outputs.
    """
    model.eval()
    y_hats = []
    ys = []
    # Iterate the dataloader directly: the batch index was never used, and
    # wrapping in `enumerate` hides any length from tqdm.
    for minibatch in tqdm(dataloader, "evaluating"):
        pyg_data = minibatch.to_pyg_data()
        y = pyg_data.y
        # Only the first `len(y)` output rows are compared with the labels.
        out = model(pyg_data.x, pyg_data.edge_index)[: y.shape[0]]
        y_hats.append(out)
        ys.append(y)

    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
@torch.no_grad()
def layerwise_infer(
    model, args, infer_dataloader, test_set, feature, num_classes, device
):
    """Score ``model`` on ``test_set`` using its layer-wise inference.

    The "feat" node feature is read from ``feature`` and passed through
    ``model.inference``. The resulting rows are then indexed by the first
    entry of ``test_set._items`` and compared with the second entry
    (presumably node IDs and labels -- confirm against GraphBolt's ItemSet).

    Returns:
        The multiclass accuracy returned by ``MF.accuracy``.
    """
    model.eval()
    node_feats = feature.read("node", None, "feat")
    all_outputs = model.inference(args, infer_dataloader, node_feats, device)

    # Keep only the rows belonging to the test items.
    test_outputs = all_outputs[test_set._items[0]]
    # Labels must live on the same device as the predictions.
    test_labels = test_set._items[1].to(test_outputs.device)

    return MF.accuracy(
        test_outputs,
        test_labels,
        task="multiclass",
        num_classes=num_classes,
    )


204
205
206
207
208
209
210
def main():
    """Parse arguments, train GraphSAGE, and report test accuracy.

    Loads the requested GraphBolt built-in dataset, builds the train,
    validation and inference dataloaders, trains for ``--epochs`` epochs
    while printing per-epoch metrics, and finally prints the test accuracy
    obtained with layer-wise inference.
    """
    parser = argparse.ArgumentParser(
        description="Which dataset are you going to use?"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbn-products",
        help='Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")',
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument(
        "--batch-size", type=int, default=1024, help="Batch size for training."
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = gb.BuiltinDataset(args.dataset).load()
    graph = dataset.graph
    feature = dataset.feature
    # Pin the features only when a CUDA device is in use, so the CPU
    # fallback chosen above does not attempt to pin memory.
    # NOTE(review): assumes `pin_memory_()` needs CUDA -- confirm.
    if device.type == "cuda":
        feature = feature.pin_memory_()

    task = dataset.tasks[0]
    train_set = task.train_set
    valid_set = task.validation_set
    test_set = task.test_set
    all_nodes_set = dataset.all_nodes_set
    num_classes = task.metadata["num_classes"]

    # One fanout entry per model layer, shared by training and validation.
    fanout = [5, 10, 15]

    train_dataloader = create_dataloader(
        train_set,
        graph,
        feature,
        args.batch_size,
        fanout,
        device,
        job="train",
    )
    valid_dataloader = create_dataloader(
        valid_set,
        graph,
        feature,
        args.batch_size,
        fanout,
        device,
        job="evaluate",
    )
    # The inference batch size must stay `4 * args.batch_size`:
    # `GraphSAGE.inference` slices its outputs with the same expression.
    infer_dataloader = create_dataloader(
        all_nodes_set,
        graph,
        feature,
        4 * args.batch_size,
        [-1],
        device,
        job="infer",
    )

    in_channels = feature.size("node", None, "feat")[0]
    hidden_channels = 256
    model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

    for epoch in range(args.epochs):
        train_loss, train_accuracy = train(model, train_dataloader, optimizer)
        valid_accuracy = evaluate(model, valid_dataloader, num_classes)
        print(
            f"Epoch {epoch}, Train Loss: {train_loss:.4f}, "
            f"Train Accuracy: {train_accuracy:.4f}, "
            f"Valid Accuracy: {valid_accuracy:.4f}"
        )

    test_accuracy = layerwise_infer(
        model, args, infer_dataloader, test_set, feature, num_classes, device
    )
    print(f"Test Accuracy: {test_accuracy:.4f}")


# Run the example only when this file is executed as a script, so that
# importing it has no side effects.
if __name__ == "__main__":
    main()