"""
This script demonstrates node classification with GraphSAGE on large graphs, 
3
4
5
6
7
8
9
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently
manages data loading for large datasets, crucial for mini-batch processing.
Post data loading, PyG's user-friendly framework takes over for training,
showcasing seamless integration with GraphBolt. This combination offers an
efficient alternative to traditional Deep Graph Library (DGL) methods,
highlighting adaptability and scalability in handling large-scale graph data
for diverse real-world applications.

Key Features:
- Implements the GraphSAGE model, a scalable GNN, for node classification on
  large graphs.
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- The script is well-documented, providing clear explanations at each step.

This flowchart describes the main functional sequence of the provided example.

main
│
├───> Load and preprocess dataset (GraphBolt)
│     │
│     └───> Utilize GraphBolt's BuiltinDataset for dataset handling
│
├───> Instantiate the SAGE model (PyTorch Geometric)
│     │
│     └───> Define the GraphSAGE model architecture
│
├───> Train the model
│     │
│     ├───> Mini-Batch Processing with GraphBolt
│     │     │
│     │     └───> Efficient handling of mini-batches using GraphBolt's utilities
│     │
│     └───> Training Loop
│           │
│           ├───> Convert GraphBolt MiniBatch to PyG Data
│           │
│           ├───> Forward and backward passes
│           │
│           └───> Parameter optimization
│
└───> Evaluate the model
      │
      └───> Performance assessment on validation and test datasets
            │
            └───> Accuracy and other relevant metrics calculation

"""

import argparse

import dgl.graphbolt as gb
import torch
import torch.nn.functional as F
import torchmetrics.functional as MF
from torch_geometric.nn import SAGEConv
from tqdm import tqdm


class GraphSAGE(torch.nn.Module):
    #####################################################################
    # (HIGHLIGHT) Define the GraphSAGE model architecture.
    #
    # - This class inherits from `torch.nn.Module`.
    # - Three convolutional layers are created using the SAGEConv class
    #   from PyG.
    # - 'in_size', 'hidden_size', 'out_size' are the sizes of the input,
    #   hidden, and output features, respectively.
    # - The forward method defines the computation performed at every call.
    # - It's adapted from the official PyG example which can be found at
    # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
    #####################################################################
    def __init__(self, in_size, hidden_size, out_size):
        super(GraphSAGE, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(SAGEConv(in_size, hidden_size))
        self.layers.append(SAGEConv(hidden_size, hidden_size))
        self.layers.append(SAGEConv(hidden_size, out_size))

    def forward(self, x, edge_index):
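        # Apply each SAGEConv layer; ReLU and dropout follow every layer
        # except the last, whose raw outputs serve as the class logits.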
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if i != len(self.layers) - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    def inference(self, dataloader, x_all, device):
        """Conduct layer-wise inference to get all the node embeddings."""
        for i, layer in tqdm(enumerate(self.layers), "inference"):
            xs = []
            for minibatch in dataloader:
                # Call `to_pyg_data` to convert GB Minibatch to PyG Data.
                pyg_data = minibatch.to_pyg_data()
                n_id = pyg_data.n_id.to("cpu")
                x = x_all[n_id].to(device)
                edge_index = pyg_data.edge_index
                x = layer(x, edge_index)
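                # The seed nodes occupy the first `batch_size` rows of the
                # output, so slice off everything else.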
                x = x[: pyg_data.batch_size]
                if i != len(self.layers) - 1:
                    x = x.relu()
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)
        return x_all


def create_dataloader(
    dataset_set, graph, feature, batch_size, fanout, device, job
):
    # Initialize an ItemSampler to sample mini-batches from the dataset.
    datapipe = gb.ItemSampler(
        dataset_set,
        batch_size=batch_size,
        shuffle=(job == "train"),
        drop_last=(job == "train"),
    )
    # Sample neighbors for each node in the mini-batch.
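    # A fanout of [-1] selects all neighbors; layer-wise inference requires
    # full neighborhoods, while training uses the smaller per-layer fanouts.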
    datapipe = datapipe.sample_neighbor(
        graph, fanout if job != "infer" else [-1]
    )
    # Copy the data to the specified device.
    datapipe = datapipe.copy_to(device=device, extra_attrs=["input_nodes"])
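    # `extra_attrs=["input_nodes"]` also moves the sampled input node IDs to
    # the target device so the feature fetch below can use them there.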
    # Fetch node features for the sampled subgraph.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    # Create and return a DataLoader to handle data loading.
    dataloader = gb.DataLoader(datapipe, num_workers=0)

    return dataloader


def train(model, dataloader, optimizer):
    model.train()  # Set the model to training mode
    total_loss = 0  # Accumulator for the total loss
    total_correct = 0  # Accumulator for the total number of correct predictions
    total_samples = 0  # Accumulator for the total number of samples processed
    num_batches = 0  # Counter for the number of mini-batches processed

    for _, minibatch in tqdm(enumerate(dataloader), "training"):
        #####################################################################
        # (HIGHLIGHT) Convert GraphBolt MiniBatch to PyG Data class.
        #
        # Call `MiniBatch.to_pyg_data()` and it will return a PyG Data class
        # with necessary data and information.
        #####################################################################
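        # The returned Data object carries, among other fields, `x` (input
        # node features), `edge_index` (sampled subgraph connectivity), and
        # `y` (labels of the seed nodes).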
        pyg_data = minibatch.to_pyg_data()

        optimizer.zero_grad()
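        # Seed nodes come first in the node ordering, so slicing the output
        # to the label count keeps only the seed-node predictions.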
        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
        y = pyg_data.y
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(y).sum())
        total_samples += y.shape[0]
        num_batches += 1
    avg_loss = total_loss / num_batches
    avg_accuracy = total_correct / total_samples
    return avg_loss, avg_accuracy


@torch.no_grad()
def evaluate(model, dataloader, num_classes):
    model.eval()
    y_hats = []
    ys = []
    for _, minibatch in tqdm(enumerate(dataloader), "evaluating"):
        pyg_data = minibatch.to_pyg_data()
        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
        y = pyg_data.y
        y_hats.append(out)
        ys.append(y)

    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


@torch.no_grad()
def layerwise_infer(
    model, infer_dataloader, test_set, feature, num_classes, device
):
    model.eval()
    features = feature.read("node", None, "feat")
    pred = model.inference(infer_dataloader, features, device)
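    # `test_set._items` holds a (node IDs, labels) pair; indexing the full
    # prediction matrix with the test node IDs selects the test predictions.
    # Note that `_items` is a private attribute of the item set.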
    pred = pred[test_set._items[0]]
    label = test_set._items[1].to(pred.device)

    return MF.accuracy(
        pred,
        label,
        task="multiclass",
        num_classes=num_classes,
    )


def main():
    parser = argparse.ArgumentParser(
        description="GraphSAGE node classification with GraphBolt and PyG."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbn-products-seeds",
        help='Name of the dataset to use (e.g., "ogbn-products-seeds",'
        + ' "ogbn-arxiv-seeds")',
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument(
        "--batch-size", type=int, default=1024, help="Batch size for training."
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_name = args.dataset
    dataset = gb.BuiltinDataset(dataset_name).load()
    graph = dataset.graph
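    # Pin the feature tensors in page-locked host memory to speed up
    # host-to-device copies during feature fetching.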
    feature = dataset.feature.pin_memory_()
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    test_set = dataset.tasks[0].test_set
    all_nodes_set = dataset.all_nodes_set
    num_classes = dataset.tasks[0].metadata["num_classes"]

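    # One fanout entry per GNN layer; the three entries below match the
    # three SAGEConv layers of the GraphSAGE model.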
    train_dataloader = create_dataloader(
        train_set,
        graph,
        feature,
        args.batch_size,
        [5, 10, 15],
        device,
        job="train",
    )
    valid_dataloader = create_dataloader(
        valid_set,
        graph,
        feature,
        args.batch_size,
        [5, 10, 15],
        device,
        job="evaluate",
    )
    infer_dataloader = create_dataloader(
        all_nodes_set,
        graph,
        feature,
        4 * args.batch_size,
        [-1],
        device,
        job="infer",
    )
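    # `feature.size("node", None, "feat")` returns the shape of a single
    # node's feature vector; its first entry is the input dimensionality.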
    in_channels = feature.size("node", None, "feat")[0]
    hidden_channels = 256
    model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
    for epoch in range(args.epochs):
        train_loss, train_accuracy = train(model, train_dataloader, optimizer)

        valid_accuracy = evaluate(model, valid_dataloader, num_classes)
        print(
            f"Epoch {epoch}, Train Loss: {train_loss:.4f}, "
            f"Train Accuracy: {train_accuracy:.4f}, "
            f"Valid Accuracy: {valid_accuracy:.4f}"
        )
    test_accuracy = layerwise_infer(
        model, infer_dataloader, test_set, feature, num_classes, device
    )
    print(f"Test Accuracy: {test_accuracy:.4f}")


if __name__ == "__main__":
    main()