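"""Cluster-GCN training of a 3-layer GraphSAGE model on ogbn-products.

The graph is METIS-partitioned into clusters, and each minibatch trains on
the subgraph induced by a batch of clusters (Chiang et al., KDD 2019).
"""
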
import time

import dgl
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
from ogb.nodeproppred import DglNodePropPredDataset


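# A plain 3-layer GraphSAGE model: Cluster-GCN runs full (non-sampled)
# message passing over each cluster-induced subgraph, so the same forward
# pass serves both training and inference.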
class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)

    def forward(self, sg, x):
        h = x
        for l, layer in enumerate(self.layers):
            h = layer(sg, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h


dataset = dgl.data.AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
# AsNodePredDataset already prepares ndata['label'/'train_mask'/'val_mask'/'test_mask'].
graph = dataset[0]

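# A batch of clusters induces a subgraph small enough to train on in full,
# so both the model and the sampled subgraphs live on the GPU.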
model = SAGE(graph.ndata["feat"].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

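# ClusterGCNSampler METIS-partitions the graph into `num_partitions`
# clusters on first use (the assignment is cached on disk) and attaches
# the listed node data to every sampled subgraph.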
num_partitions = 1000
sampler = dgl.dataloading.ClusterGCNSampler(
    graph,
    num_partitions,
    prefetch_ndata=["feat", "label", "train_mask", "val_mask", "test_mask"],
)
# DataLoader performs generic dataloading given a graph, a set of indices
# (any indices -- here, partition IDs), and a graph sampler.
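# use_uva=True keeps the graph in pinned host memory and lets the GPU
# sample it and slice features directly (unified virtual addressing), so
# the full graph never needs to fit in GPU memory.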
dataloader = dgl.dataloading.DataLoader(
    graph,
    torch.arange(num_partitions).to("cuda"),
    sampler,
    device="cuda",
    batch_size=100,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=True,
)

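# Train for 10 epochs and time each one; with batch_size=100 out of 1000
# partitions, every epoch runs 10 subgraph iterations.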
durations = []
for _ in range(10):
    t0 = time.time()
    model.train()
    for it, sg in enumerate(dataloader):
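        # Node data listed in prefetch_ndata are already attached to the
        # subgraph; mask the loss to the training nodes it contains.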
        x = sg.ndata["feat"]
        y = sg.ndata["label"]
        m = sg.ndata["train_mask"].bool()
        y_hat = model(sg, x)
        loss = F.cross_entropy(y_hat[m], y[m])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            # torchmetrics >= 0.11 requires the task to be given explicitly.
            acc = MF.accuracy(
                y_hat[m], y[m], task="multiclass", num_classes=dataset.num_classes
            )
            mem = torch.cuda.max_memory_allocated() / 1000000  # bytes -> MB
            print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)

    model.eval()
    with torch.no_grad():
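        # Inference reuses the same cluster dataloader: accumulate the
        # masked predictions and labels over all subgraphs, then score once.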
        val_preds, test_preds = [], []
        val_labels, test_labels = [], []
        for it, sg in enumerate(dataloader):
            x = sg.ndata["feat"]
            y = sg.ndata["label"]
            m_val = sg.ndata["val_mask"].bool()
            m_test = sg.ndata["test_mask"].bool()
            y_hat = model(sg, x)
            val_preds.append(y_hat[m_val])
            val_labels.append(y[m_val])
            test_preds.append(y_hat[m_test])
            test_labels.append(y[m_test])
        val_preds = torch.cat(val_preds, 0)
        val_labels = torch.cat(val_labels, 0)
        test_preds = torch.cat(test_preds, 0)
        test_labels = torch.cat(test_labels, 0)
        # torchmetrics >= 0.11 requires the task to be given explicitly.
        val_acc = MF.accuracy(
            val_preds, val_labels, task="multiclass", num_classes=dataset.num_classes
        )
        test_acc = MF.accuracy(
            test_preds, test_labels, task="multiclass", num_classes=dataset.num_classes
        )
        print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())

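# Mean and std of epoch time, skipping the first four warm-up epochs.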
print(np.mean(durations[4:]), np.std(durations[4:]))