"vscode:/vscode.git/clone" did not exist on "df40b11d03b63d6d746c3c9c532fb2dda50f56b9"
train_gin.py 6.11 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import argparse

import torch
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau as ReduceLR
from torch.nn import Identity, Sequential, Linear, ReLU, BatchNorm1d
from torch_sparse import SparseTensor
import torch_geometric.transforms as T
from torch_geometric.nn import GINConv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset as SBM

from torch_geometric_autoscale import get_data
from torch_geometric_autoscale import metis, permute
from torch_geometric_autoscale.models import ScalableGNN
from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(123)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

data, in_channels, out_channels = get_data(args.root, name='CLUSTER')
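# `data` holds the full graph with a sparse `adj_t` adjacency;
# `in_channels`/`out_channels` are the node feature and class dimensions.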

# Pre-partition the graph using METIS:
perm, ptr = metis(data.adj_t, num_parts=10000, log=True)
data = permute(data, perm, log=True)
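# `perm` reorders the nodes so that every METIS partition becomes a contiguous
# block of rows, with partition boundaries given by `ptr`.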

train_loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True,
                              num_workers=6, persistent_workers=True)
eval_loader = EvalSubgraphLoader(data, ptr, batch_size=256)
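# `batch_size` here counts METIS partitions per mini-batch (not nodes); the
# evaluation loader covers the same partitions without shuffling.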

# We use the regular PyTorch Geometric datasets for validation and testing:
kwargs = {'name': 'CLUSTER', 'pre_transform': T.ToSparseTensor()}
val_dataset = SBM(f'{args.root}/SBM', split='val', **kwargs)
test_dataset = SBM(f'{args.root}/SBM', split='test', **kwargs)
val_loader = DataLoader(val_dataset, batch_size=512)
test_loader = DataLoader(test_dataset, batch_size=512)


# We define our own GAS+GIN module:
class GIN(ScalableGNN):
    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__(num_nodes, hidden_channels, num_layers, pool_size=2,
                         buffer_size=60000)
        # pool_size determines the number of pinned CPU buffers
        # buffer_size determines the size of pinned CPU buffers,
        #             i.e. the maximum number of out-of-mini-batch nodes

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(in_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, out_channels))

        # The GINConv uses an `Identity` transform here; the per-layer MLPs are
        # stored separately and applied explicitly in `forward`:
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GINConv(Identity(), train_eps=True))

        self.mlps = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = Sequential(
                Linear(hidden_channels, hidden_channels),
                BatchNorm1d(hidden_channels, track_running_stats=False),
                ReLU(),
                Linear(hidden_channels, hidden_channels),
                ReLU(),
            )
            self.mlps.append(mlp)

    def forward(self, x: Tensor, adj_t: SparseTensor, *args):
        x = self.lins[0](x).relu_()

        reg = 0
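        # `self.histories` (created by `ScalableGNN`) stores one embedding
        # history per hidden layer: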
        it = zip(self.convs[:-1], self.mlps[:-1], self.histories)
        for i, (conv, mlp, history) in enumerate(it):
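            # Bipartite message passing: the first `adj_t.size(0)` rows of `x`
            # correspond to the in-mini-batch (target) nodes: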
            h = conv((x, x[:adj_t.size(0)]), adj_t)

            # Regularize towards Lipschitz continuity (part 1): compute the MLP
            # output for a slightly perturbed input:
            if i > 0 and self.training:
                approx = mlp(h + 0.1 * torch.randn_like(h))

            h = mlp(h)

            # Regularize towards Lipschitz continuity (part 2): penalize the
            # distance between the clean and perturbed outputs:
            if i > 0 and self.training:
                diff = (h - approx).norm(dim=-1)
                reg += diff.mean() / len(self.histories)

            h += x[:h.size(0)]  # Simple skip-connection
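            # Push the updated in-mini-batch embeddings to this layer's history
            # and pull historical embeddings for the out-of-mini-batch neighbors: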
            x = self.push_and_pull(history, h, *args)

        h = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
        h = self.mlps[-1](h)
        h += x[:h.size(0)]
        x = self.lins[1](h)

        return x, reg

    @torch.no_grad()
    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state):
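        # Single-layer forward pass used for layer-wise mini-batch inference,
        # e.g. when the model is called as `model(loader=...)` in `mini_test` below.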
        if layer == 0:
            x = self.lins[0](x).relu_()

        h = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)
        h = self.mlps[layer](h)
        h += x[:h.size(0)]

        if layer == self.num_layers - 1:
            h = self.lins[1](h)

        return h


model = GIN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=128,
    out_channels=out_channels,
    num_layers=4,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
scheduler = ReduceLR(optimizer, 'max', factor=0.5, patience=20, min_lr=1e-5)
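# Halve the learning rate whenever the validation accuracy plateaus
# (see `scheduler.step(val_acc)` in the training loop below).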


def train(model, loader, optimizer):
    model.train()

    total_loss = total_examples = 0
    for batch, *args in loader:
        batch = batch.to(model.device)
        optimizer.zero_grad()
        out, reg = model(batch.x, batch.adj_t, *args)
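        # Only the first `out.size(0)` rows belong to in-mini-batch nodes, so
        # the labels are sliced accordingly: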
        loss = criterion(out, batch.y[:out.size(0)]) + reg
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * int(out.size(0))
        total_examples += int(out.size(0))

    return total_loss / total_examples


@torch.no_grad()
def full_test(model, loader):
    model.eval()

    total_correct = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        out, _ = model(batch.x, batch.adj_t)
        total_correct += int((out.argmax(dim=-1) == batch.y).sum())
        total_examples += out.size(0)

    return total_correct / total_examples


@torch.no_grad()
def mini_test(model, loader, y):
    model.eval()
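    # Layer-wise inference over all partitions, reading from and refreshing the
    # history buffers: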
    out = model(loader=loader)
    return int((out.argmax(dim=-1) == y).sum()) / y.size(0)


mini_test(model, eval_loader, data.y)  # Fill history.

for epoch in range(1, 151):
    lr = optimizer.param_groups[0]['lr']
    loss = train(model, train_loader, optimizer)
    train_acc = mini_test(model, eval_loader, data.y)
    val_acc = full_test(model, val_loader)
    test_acc = full_test(model, test_loader)
    scheduler.step(val_acc)
    print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')