import argparse

import torch
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau as ReduceLR
from torch.nn import Identity, Sequential, Linear, ReLU, BatchNorm1d
from torch_sparse import SparseTensor
import torch_geometric.transforms as T
from torch_geometric.nn import GINConv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset as SBM

from torch_geometric_autoscale import get_data
from torch_geometric_autoscale import metis, permute
from torch_geometric_autoscale.models import ScalableGNN
from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(123)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

data, in_channels, out_channels = get_data(args.root, name='CLUSTER')

# Pre-partition the graph using Metis:
perm, ptr = metis(data.adj_t, num_parts=10000, log=True)
data = permute(data, perm, log=True)
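# After permutation, the nodes of each partition are stored consecutively, and
# `ptr` holds the partition boundaries: partition i covers nodes
# ptr[i]:ptr[i + 1].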

train_loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True,
                              num_workers=6, persistent_workers=True)
eval_loader = EvalSubgraphLoader(data, ptr, batch_size=256)
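# `batch_size` counts Metis partitions, not nodes: each mini-batch merges 256
# partitions together with their one-hop neighborhood. Besides the batch
# itself, the loader yields the bookkeeping arguments that `push_and_pull`
# needs to address the histories (forwarded below via `*args`).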

# We use the regular PyTorch Geometric dataset for evaluation:
kwargs = {'name': 'CLUSTER', 'pre_transform': T.ToSparseTensor()}
val_dataset = SBM(f'{args.root}/SBM', split='val', **kwargs)
test_dataset = SBM(f'{args.root}/SBM', split='test', **kwargs)
val_loader = DataLoader(val_dataset, batch_size=512)
test_loader = DataLoader(test_dataset, batch_size=512)
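# CLUSTER consists of many small SBM graphs, so validation and testing can run
# exactly on full graphs, without any history-based approximation.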


# We define our own GAS+GIN module:
class GIN(ScalableGNN):
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__(num_nodes, hidden_channels, num_layers, pool_size=2,
                         buffer_size=60000)
        # pool_size determines the number of pinned CPU buffers
        # buffer_size determines the size of pinned CPU buffers,
        #             i.e. the maximum number of out-of-mini-batch nodes
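        # Using a pool of pinned buffers lets history reads/writes overlap
        # with GPU computation via asynchronous CPU-GPU transfers.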

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(in_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, out_channels))

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # `Identity()` as the GINConv network: the MLPs are applied
            # separately in `forward` so they can be re-used for the
            # regularization term.
            self.convs.append(GINConv(Identity(), train_eps=True))

        self.mlps = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = Sequential(
                Linear(hidden_channels, hidden_channels),
                BatchNorm1d(hidden_channels, track_running_stats=False),
                ReLU(),
                Linear(hidden_channels, hidden_channels),
                ReLU(),
            )
            self.mlps.append(mlp)

    def forward(self, x: Tensor, adj_t: SparseTensor, *args):
        x = self.lins[0](x).relu_()

        reg = 0
        it = zip(self.convs[:-1], self.mlps[:-1], self.histories)
        for i, (conv, mlp, history) in enumerate(it):
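            # `adj_t` is bipartite: its columns (sources) include
            # out-of-mini-batch neighbors, while its rows (targets) are the
            # first `adj_t.size(0)` in-mini-batch nodes.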
            h = conv((x, x[:adj_t.size(0)]), adj_t)

            # Encourage Lipschitz continuity via a noise-injection penalty (part 1):
            if i > 0 and self.training:
                approx = mlp(h + 0.1 * torch.randn_like(h))

            h = mlp(h)

            # Encourage Lipschitz continuity via a noise-injection penalty (part 2):
            if i > 0 and self.training:
                diff = (h - approx).norm(dim=-1)
                reg += diff.mean() / len(self.histories)

            h += x[:h.size(0)]  # Simple skip-connection
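            # Write the fresh in-mini-batch embeddings into the history and
            # read back (possibly stale) historical embeddings for all
            # out-of-mini-batch nodes: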
            x = self.push_and_pull(history, h, *args)

        # The final layer has no associated history, since its output is
        # never re-used by subsequent layers:
        h = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
        h = self.mlps[-1](h)
        h += x[:h.size(0)]
        x = self.lins[1](h)

        return x, reg

    @torch.no_grad()
    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state):
        # Perform layer-wise inference:
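        # The base class calls this once per layer during `model(loader=...)`,
        # computing exact embeddings partition-by-partition and refreshing the
        # corresponding history.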
        if layer == 0:
            x = self.lins[0](x).relu_()

        h = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)
        h = self.mlps[layer](h)
        h += x[:h.size(0)]

        if layer == self.num_layers - 1:
            h = self.lins[1](h)

        return h


model = GIN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=128,
    out_channels=out_channels,
    num_layers=4,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
scheduler = ReduceLR(optimizer, 'max', factor=0.5, patience=20, min_lr=1e-5)
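# 'max' mode: the learning rate is halved whenever validation accuracy
# plateaus for `patience` epochs.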


def train(model, loader, optimizer):
    model.train()

    total_loss = total_examples = 0
    for batch, *args in loader:
        batch = batch.to(model.device)
        optimizer.zero_grad()
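        # Predictions are only produced for in-mini-batch nodes, so the
        # targets are sliced to match: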
        out, reg = model(batch.x, batch.adj_t, *args)
        loss = criterion(out, batch.y[:out.size(0)]) + reg
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * int(out.size(0))
        total_examples += int(out.size(0))

    return total_loss / total_examples


@torch.no_grad()
def mini_test(model, loader, y):
    model.eval()
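    # Passing `loader=` to the model triggers layer-wise inference over the
    # full graph via `forward_layer`, updating all histories along the way.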
    out = model(loader=loader)
    return int((out.argmax(dim=-1) == y).sum()) / y.size(0)


@torch.no_grad()
def full_test(model, loader):
    model.eval()

    total_correct = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        out, _ = model(batch.x, batch.adj_t)
        total_correct += int((out.argmax(dim=-1) == batch.y).sum())
        total_examples += out.size(0)

    return total_correct / total_examples


mini_test(model, eval_loader, data.y)  # Fill history.
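# The call above populates the histories, so the first training epoch already
# pulls meaningful (rather than zero-initialized) embeddings.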

for epoch in range(1, 151):
    lr = optimizer.param_groups[0]['lr']
    loss = train(model, train_loader, optimizer)
    train_acc = mini_test(model, eval_loader, data.y)
    val_acc = full_test(model, val_loader)
    test_acc = full_test(model, test_loader)
    scheduler.step(val_acc)
    print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')