# train_gcn.py: train the pre-defined GCN+GAS model from
# torch_geometric_autoscale on the Cora dataset.
import argparse

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import metis, permute, SubgraphLoader
from torch_geometric_autoscale import get_data, compute_micro_f1

# Command-line options: where the dataset lives and which CUDA device to use.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--root',
    type=str,
    required=True,
    help='Root directory of dataset storage.',
)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

# Seed torch's RNG so runs are reproducible.
torch.manual_seed(12345)

# Use the requested GPU when CUDA is present, otherwise run on the CPU.
if torch.cuda.is_available():
    device = f'cuda:{args.device}'
else:
    device = 'cpu'

# Load Cora; `in_channels` / `out_channels` are used below to size the model.
data, in_channels, out_channels = get_data(args.root, name='cora')

# Pre-process adjacency matrix for GCN (GCN normalization with self-loops),
# done once up front on the full graph.
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)

# Pre-partition the graph using METIS and reorder the nodes by `perm`.
# NOTE(review): `ptr` presumably holds the partition boundaries in the
# permuted node order -- confirm against `metis` / `permute`.
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)

# Mini-batch loader over the partitions, reshuffled by the loader.
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

# Make use of the pre-defined GCN+GAS model:
model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    drop_input=True,
    pool_size=1,  # Number of pinned CPU buffers
    buffer_size=500,  # Size of pinned CPU buffers (max #out-of-batch nodes)
).to(device)

# Two parameter groups: weight decay is applied only to the parameters of
# `model.reg_modules`; those of `model.nonreg_modules` are not decayed.
optimizer = torch.optim.Adam([
    dict(params=model.reg_modules.parameters(), weight_decay=5e-4),
    dict(params=model.nonreg_modules.parameters(), weight_decay=0),
], lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


def train(model, loader, optimizer):
    """Run one training epoch of ``model`` over the mini-batches in ``loader``.

    Each item yielded by ``loader`` is a ``(batch, *extra)`` tuple; ``extra``
    is forwarded unchanged to the model's forward call. The loss is the
    module-level ``criterion`` evaluated on the training nodes among the first
    ``out.size(0)`` nodes of the batch (the nodes the model returned outputs
    for). Gradients are clipped to a total norm of 1.0 before each step.

    Args:
        model: The model being trained; must expose a ``device`` attribute.
        loader: Iterable of ``(batch, *extra)`` tuples (e.g. a
            ``SubgraphLoader``).
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The mean training loss over all training nodes seen this epoch, or
        ``nan`` if no batch contained a training node.
    """
    model.train()

    total_loss = 0.0
    total_nodes = 0
    # Named `extra` (not `args`) so the module-level argparse namespace is
    # not shadowed inside this function.
    for batch, *extra in loader:
        batch = batch.to(model.device)
        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, *extra)

        # Only the first `out.size(0)` nodes of the batch have predictions.
        num_out = out.size(0)
        train_mask = batch.train_mask[:num_out]
        num_train = int(train_mask.sum())
        if num_train == 0:
            # The forward pass above still ran (GAS models may update their
            # histories during it). A mean-reduced loss over zero nodes is
            # NaN, and stepping an optimizer with momentum/weight decay
            # (e.g. Adam) on all-zero gradients would still move the
            # weights, so skip the update for this batch.
            continue

        loss = criterion(out[train_mask], batch.y[:num_out][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += float(loss) * num_train
        total_nodes += num_train

    return total_loss / total_nodes if total_nodes else float('nan')


@torch.no_grad()
def test(model, data):
    """Evaluate ``model`` on the whole graph in a single forward pass.

    Gradient tracking is disabled by the ``torch.no_grad`` decorator and the
    model is switched to evaluation mode.

    Args:
        model: The model to evaluate; must expose a ``device`` attribute.
        data: The graph, with ``x``, ``adj_t``, ``y`` and the
            ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.

    Returns:
        A ``(train_acc, val_acc, test_acc)`` tuple, each computed by
        ``compute_micro_f1`` on the corresponding mask.
    """
    model.eval()

    # Full-batch inference since the graph is small.
    device = model.device
    out = model(data.x.to(device), data.adj_t.to(device)).cpu()

    train_acc = compute_micro_f1(out, data.y, data.train_mask)
    val_acc = compute_micro_f1(out, data.y, data.val_mask)
    test_acc = compute_micro_f1(out, data.y, data.test_mask)

    return train_acc, val_acc, test_acc


test(model, data)  # Fill the history.

# Model selection: `test_acc` tracks the test accuracy of the epoch with the
# best validation accuracy seen so far (reported as "Final").
best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train(model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(model, data)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')