"""
This example shows how to create a GraphBolt dataloader to sample and train a
node classification model with the Cora dataset.
"""
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF


############################################################################
# (HIGHLIGHT) Create a single-process dataloader with the DGL GraphBolt
# package.
############################################################################
def create_dataloader(
    dataset, itemset, device, *, batch_size=16, fanouts=(4, 2)
):
    """Build a GraphBolt dataloader that yields neighbor-sampled mini-batches.

    The pipeline batches seed nodes from ``itemset``, copies each mini-batch
    to ``device``, samples neighbors from ``dataset.graph`` and finally
    fetches the ``"feat"`` node features from ``dataset.feature``.

    Parameters
    ----------
    dataset
        Loaded GraphBolt dataset; its ``graph`` and ``feature`` attributes
        are used for sampling and feature fetching.
    itemset
        Item set holding the seed nodes to iterate over.
    device
        Device the mini-batches are copied to.
    batch_size : int, optional
        Number of seed nodes per mini-batch. Defaults to 16.
    fanouts : sequence of int, optional
        Neighbor fanouts forwarded to ``sample_neighbor``. Defaults to
        ``(4, 2)``. NOTE(review): presumably one entry per GNN layer —
        confirm against the model being trained.

    Returns
    -------
    gb.DataLoader
        Dataloader wrapping the assembled datapipe.
    """
    # Sample seed nodes from the itemset.
    datapipe = gb.ItemSampler(itemset, batch_size=batch_size)

    # Copy the mini-batch to the designated device for sampling and training.
    datapipe = datapipe.copy_to(device, extra_attrs=["seeds"])

    # Sample neighbors for the seed nodes. The fanouts are handed over as a
    # list, matching what the original hard-coded call passed.
    datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=list(fanouts))

    # Fetch features for sampled nodes.
    datapipe = datapipe.fetch_feature(
        dataset.feature, node_feature_keys=["feat"]
    )

    # Initiate the dataloader for the datapipe.
    return gb.DataLoader(datapipe)


class GCN(nn.Module):
    """Two-layer graph convolutional network built from ``dglnn.GraphConv``.

    The first layer maps ``in_size`` to ``hidden_size`` and the second maps
    ``hidden_size`` to ``out_size``.
    """

    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                dglnn.GraphConv(in_size, hidden_size),
                dglnn.GraphConv(hidden_size, out_size),
            ]
        )

    def forward(self, blocks, x):
        """Apply each layer to its paired block, with ReLU between layers.

        Layers and blocks are paired positionally. ReLU follows every layer
        except the one at the final position of ``self.layers``.
        """
        final_idx = len(self.layers) - 1
        out = x
        for idx, (conv, block) in enumerate(zip(self.layers, blocks)):
            out = conv(block, out)
            if idx != final_idx:
                out = F.relu(out)
        return out


@torch.no_grad()
def evaluate(model, dataset, itemset, device):
    """Return the multiclass accuracy of ``model`` over ``itemset``.

    The model is switched to evaluation mode, every mini-batch produced by
    ``create_dataloader`` is run through it, and the concatenated predictions
    are scored against the concatenated labels with
    ``torchmetrics.functional.accuracy``.
    """
    model.eval()
    labels = []
    predictions = []

    # Gradient tracking is disabled by the decorator for this whole loop.
    for batch in create_dataloader(dataset, itemset, device):
        features = batch.node_features["feat"]
        labels.append(batch.labels)
        predictions.append(model(batch.blocks, features))

    return MF.accuracy(
        torch.cat(predictions),
        torch.cat(labels),
        task="multiclass",
        num_classes=dataset.tasks[0].metadata["num_classes"],
    )


def train(model, dataset, device, *, num_epochs=10, lr=1e-2):
    """Train ``model`` on the dataset's first task and report metrics per epoch.

    Each epoch iterates over all training mini-batches, optimizing the
    cross-entropy loss with Adam, then evaluates on the validation and test
    sets and prints the mean training loss together with both accuracies.

    Parameters
    ----------
    model
        Module called as ``model(blocks, features)``.
    dataset
        Loaded GraphBolt dataset whose ``tasks[0]`` provides ``train_set``,
        ``validation_set`` and ``test_set``.
    device
        Device the mini-batches are copied to.
    num_epochs : int, optional
        Number of passes over the training set. Defaults to 10.
    lr : float, optional
        Learning rate for the Adam optimizer. Defaults to 1e-2.
    """
    # The first of two tasks in the dataset is node classification.
    task = dataset.tasks[0]
    dataloader = create_dataloader(dataset, task.train_set, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        total_loss = 0.0
        num_batches = 0
        ########################################################################
        # (HIGHLIGHT) Iterate over the dataloader and train the model with all
        # mini-batches.
        ########################################################################
        for data in dataloader:
            # The features of sampled nodes.
            x = data.node_features["feat"]

            # The ground truth labels of the seed nodes.
            y = data.labels

            # Forward.
            y_hat = model(data.blocks, x)

            # Compute loss.
            loss = F.cross_entropy(y_hat, y)

            # Backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Evaluate the model.
        val_acc = evaluate(model, dataset, task.validation_set, device)
        test_acc = evaluate(model, dataset, task.test_set, device)

        # Count batches explicitly so that a dataloader yielding nothing
        # reports NaN instead of failing on an unbound loop variable.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(
            f"Epoch {epoch:03d} | Loss {mean_loss:.3f} | "
            f"Val Acc {val_acc.item():.3f} | Test Acc {test_acc.item():.3f}"
        )


if __name__ == "__main__":
    # Prefer the first CUDA device when one is available; otherwise use CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training in {device} mode.")

    # Load and preprocess dataset.
    print("Loading data...")
    dataset = gb.BuiltinDataset("cora").load()

    # If a CUDA device is selected, we pin the graph and the features so that
    # the GPU can access them. Checking the device type keeps this correct
    # regardless of the CUDA device index.
    if device.type == "cuda":
        dataset.graph.pin_memory_()
        dataset.feature.pin_memory_()

    # Model dimensions are read from the dataset: the size of the "feat"
    # node feature and the number of classes of the first task.
    in_size = dataset.feature.size("node", None, "feat")[0]
    out_size = dataset.tasks[0].metadata["num_classes"]
    model = GCN(in_size, out_size).to(device)

    # Model training.
    print("Training...")
    train(model, dataset, device)