Commit a4f271f7 authored by rusty1s

train gin

parent 0afed7cc
@@ -60,9 +60,3 @@ python setup.py install
* **`large_benchmark/`** includes experiments to evaluate GAS performance on *large-scale* graphs

We use [**Hydra**](https://hydra.cc/) to manage hyperparameter configurations.

-## Running tests
-```
-python setup.py test
-```
@@ -27,6 +27,7 @@ data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

+# Make use of the pre-defined GCN+GAS model:
model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
@@ -46,11 +47,11 @@ optimizer = torch.optim.Adam([
criterion = torch.nn.CrossEntropyLoss()

-def train(data, model, loader, optimizer):
+def train(model, loader, optimizer):
    model.train()

    for batch, batch_size, n_id, offset, count in loader:
-        batch = batch.to(device)
+        batch = batch.to(model.device)
        train_mask = batch.train_mask[:batch_size]
        optimizer.zero_grad()
@@ -76,10 +77,10 @@ def test(data, model):
test(data, model) # Fill history.

best_val_acc = test_acc = 0
for epoch in range(1, 201):
-    train(data, model, loader, optimizer)
+    train(model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(data, model)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
-    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}'
+    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')
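
# New file added by this commit (presumably something like examples/train_gin.py):
# it trains a GIN+GAS model on the CLUSTER dataset from the GNNBenchmarkDataset
# collection.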
import argparse
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import Identity, Sequential, Linear, ReLU, BatchNorm1d
import torch_geometric.transforms as T
from torch_geometric.nn import GINConv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import GNNBenchmarkDataset as SBM
from torch_geometric_autoscale.models import ScalableGNN
from torch_geometric_autoscale import (get_data, SubgraphLoader,
EvalSubgraphLoader)
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()
torch.manual_seed(123)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
data, in_channels, out_channels = get_data(args.root, name='CLUSTER')
train_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='train',
pre_transform=T.ToSparseTensor())
val_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='val',
pre_transform=T.ToSparseTensor())
test_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='test',
pre_transform=T.ToSparseTensor())
val_loader = DataLoader(val_dataset, batch_size=512)
test_loader = DataLoader(test_dataset, batch_size=512)
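# Build the partition pointer `ptr`: every training graph is split into two
# consecutive halves of its nodes, so each graph contributes two GAS partitions.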
ptr = [0]
for d in train_dataset: # Minimize inter-connectivity between batches:
ptr += [ptr[-1] + d.num_nodes // 2, ptr[-1] + d.num_nodes]
ptr = torch.tensor(ptr)
train_loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True,
num_workers=6, persistent_workers=True)
eval_loader = EvalSubgraphLoader(data, ptr, batch_size=256)
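
# GIN on top of ScalableGNN: the base class keeps a history embedding of size
# `hidden_channels` for every node and layer; `pool_size` and `buffer_size`
# configure the buffer pool used to move histories between host and device
# asynchronously.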
class GIN(ScalableGNN):
def __init__(self, num_nodes, in_channels, hidden_channels, out_channels,
num_layers):
super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size=2, buffer_size=200000)
self.out_channels = out_channels
self.lins = torch.nn.ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = torch.nn.ModuleList()
for i in range(num_layers):
self.convs.append(GINConv(Identity(), train_eps=True))
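        # With `nn=Identity()`, each GINConv only performs the eps-weighted sum
        # aggregation; the MLPs defined below are applied separately in
        # `forward`, so the Lipschitz regularizer can perturb their inputs
        # directly.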
self.mlps = torch.nn.ModuleList()
for _ in range(num_layers):
mlp = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels, track_running_stats=False),
ReLU(inplace=True),
Linear(hidden_channels, hidden_channels),
ReLU(),
)
self.mlps.append(mlp)
def forward(self, x, adj_t, batch_size=None, n_id=None, offset=None,
count=None):
reg = 0
x = self.lins[0](x).relu_()
for i, (conv, mlp, hist) in enumerate(
zip(self.convs[:-1], self.mlps[:-1], self.histories)):
h = conv((x, x[:adj_t.size(0)]), adj_t)
# Enforce Lipschitz continuity via regularization (part 1):
if i > 0 and self.training:
eps = 0.01 * torch.randn_like(h)
approx = mlp(h + eps)
h = mlp(h)
# Enforce Lipschitz continuity via regularization (part 2):
if i > 0 and self.training:
diff = (h - approx).norm(dim=-1)
reg += diff.mean() / len(self.histories)
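                # `reg` penalizes how much the MLP output moves under the small
                # random perturbation `eps`, i.e. a large local Lipschitz
                # constant of the MLP.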
h += x[:h.size(0)]
x = self.push_and_pull(hist, h, batch_size, n_id, offset, count)
h = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
h = self.mlps[-1](h)
h += x[:h.size(0)]
return self.lins[1](h), reg
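    # `forward_layer` is invoked by the ScalableGNN base class during
    # layer-wise inference (e.g. `model(loader=...)`), computing a single
    # layer for one mini-batch of nodes at a time.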
@torch.no_grad()
def forward_layer(self, layer, x, adj_t, state):
if layer == 0:
x = self.lins[0](x).relu_()
h = self.convs[layer]((x, x[:adj_t.size(0)]), adj_t)
h = self.mlps[layer](h)
h += x[:h.size(0)]
if layer == self.num_layers - 1:
h = self.lins[1](h)
return h
model = GIN(
num_nodes=train_dataset.data.num_nodes,
in_channels=in_channels,
hidden_channels=128,
out_channels=out_channels,
num_layers=4,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20,
min_lr=1e-5)
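# The scheduler tracks validation accuracy (mode='max') and halves the learning
# rate whenever it has not improved for 20 epochs.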
def train(model, loader, optimizer):
model.train()
total_loss = total_examples = 0
for batch, batch_size, n_id, offset, count in loader:
batch = batch.to(model.device)
optimizer.zero_grad()
out, reg = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
loss = criterion(out, batch.y[:batch_size]) + reg
loss.backward()
optimizer.step()
total_loss += float(loss) * int(out.size(0))
total_examples += int(out.size(0))
return total_loss / total_examples
@torch.no_grad()
def full_test(model, loader):
model.eval()
total_correct = total_examples = 0
for batch in loader:
batch = batch.to(device)
out, _ = model(batch.x, batch.adj_t)
total_correct += int((out.argmax(dim=-1) == batch.y).sum())
total_examples += out.size(0)
return total_correct / total_examples
@torch.no_grad()
def mini_test(model, loader, y):
model.eval()
out = model(loader=loader)
return int((out.argmax(dim=-1) == y).sum()) / y.size(0)
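# `full_test` runs plain full-batch inference over the val/test DataLoaders,
# while `mini_test` uses the GAS layer-wise inference path (`model(loader=...)`)
# over the training partitions, which also refreshes the histories.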
mini_test(model, eval_loader, data.y) # Fill history.
for epoch in range(1, 151):
lr = optimizer.param_groups[0]['lr']
loss = train(model, train_loader, optimizer)
train_acc = mini_test(model, eval_loader, data.y)
val_acc = full_test(model, val_loader)
test_acc = full_test(model, test_loader)
scheduler.step(val_acc)
print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.4f}, '
f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
@@ -3,7 +3,6 @@ from .gcn import GCN
# from .gat import GAT
# from .appnp import APPNP
# from .gcn2 import GCN2
-# from .gin import GIN
# from .pna import PNA
# from .pna_jk import PNA_JK
@@ -13,7 +12,6 @@ __all__ = [
    # 'GAT',
    # 'APPNP',
    # 'GCN2',
-    # 'GIN',
    # 'PNA',
    # 'PNA_JK',
]
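
# Below: the `HistoryGNN`-based GIN model (presumably
# torch_geometric_autoscale/models/gin.py), whose commented-out entries are
# removed from `__init__.py` in the hunk above.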
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Identity
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_sparse import SparseTensor
from torch_geometric.nn import GINConv
from torch_geometric.nn.inits import reset
from .base import HistoryGNN
class GIN(HistoryGNN):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, residual: bool = False,
dropout: float = 0.0, device=None, dtype=None):
super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
device, dtype)
self.in_channels = in_channels
self.out_channels = out_channels
self.residual = residual
self.dropout = dropout
self.lins = ModuleList()
self.lins.append(Linear(in_channels, hidden_channels))
self.lins.append(Linear(hidden_channels, out_channels))
self.convs = ModuleList()
for _ in range(num_layers):
conv = GINConv(nn=Identity(), train_eps=True)
self.convs.append(conv)
self.post_nns = ModuleList()
for i in range(num_layers):
post_nn = Sequential(
Linear(hidden_channels, hidden_channels),
BatchNorm1d(hidden_channels, track_running_stats=False),
ReLU(inplace=True),
Linear(hidden_channels, hidden_channels),
ReLU(inplace=True),
)
self.post_nns.append(post_nn)
def reset_parameters(self):
super(GIN, self).reset_parameters()
for conv in self.convs:
conv.reset_parameters()
for post_nn in self.post_nns:
reset(post_nn)
for lin in self.lins:
lin.reset_parameters()
def forward(self, x: Tensor, adj_t: SparseTensor,
batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None) -> Tensor:
x = self.lins[0](x).relu()
for conv, post_nn, history in zip(self.convs[:-1], self.post_nns[:-1],
self.histories):
if batch_size is not None:
h = torch.zeros_like(x)
h[:batch_size] = post_nn(conv(x, adj_t)[:batch_size])
else:
h = post_nn(conv(x, adj_t))
x = h.add_(x) if self.residual else h
x = self.push_and_pull(history, x, batch_size, n_id)
x = F.dropout(x, p=self.dropout, training=self.training)
if batch_size is not None:
h = self.post_nns[-1](self.convs[-1](x, adj_t)[:batch_size])
x = x[:batch_size]
else:
h = self.post_nns[-1](self.convs[-1](x, adj_t))
x = h.add_(x) if self.residual else h
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lins[1](x)
return x