Commit ec590171 authored by rusty1s

update training GIN script

parent a73bb262
 import argparse
 import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch import Tensor
+from torch.optim.lr_scheduler import ReduceLROnPlateau as ReduceLR
 from torch.nn import Identity, Sequential, Linear, ReLU, BatchNorm1d
+from torch_sparse import SparseTensor
 
 import torch_geometric.transforms as T
 from torch_geometric.nn import GINConv
 from torch_geometric.data import DataLoader
 from torch_geometric.datasets import GNNBenchmarkDataset as SBM
 
+from torch_geometric_autoscale import get_data
+from torch_geometric_autoscale import metis, permute
 from torch_geometric_autoscale.models import ScalableGNN
-from torch_geometric_autoscale import (get_data, SubgraphLoader,
-                                       EvalSubgraphLoader)
+from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--root', type=str, required=True,
@@ -23,32 +26,33 @@ device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
 
 data, in_channels, out_channels = get_data(args.root, name='CLUSTER')
 
-train_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='train',
-                    pre_transform=T.ToSparseTensor())
-val_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='val',
-                  pre_transform=T.ToSparseTensor())
-test_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='test',
-                   pre_transform=T.ToSparseTensor())
-val_loader = DataLoader(val_dataset, batch_size=512)
-test_loader = DataLoader(test_dataset, batch_size=512)
-
-ptr = [0]
-for d in train_dataset:
-    ptr += [ptr[-1] + d.num_nodes // 2, ptr[-1] + d.num_nodes]
-ptr = torch.tensor(ptr)
+# Pre-partition the graph using Metis:
+perm, ptr = metis(data.adj_t, num_parts=10000, log=True)
+data = permute(data, perm, log=True)
 
+# Minimize inter-connectivity between batches:
 train_loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True,
                               num_workers=6, persistent_workers=True)
 eval_loader = EvalSubgraphLoader(data, ptr, batch_size=256)
 
+# We use the regular PyTorch Geometric datasets for evaluation:
+kwargs = {'name': 'CLUSTER', 'pre_transform': T.ToSparseTensor()}
+val_dataset = SBM(f'{args.root}/SBM', split='val', **kwargs)
+test_dataset = SBM(f'{args.root}/SBM', split='test', **kwargs)
+val_loader = DataLoader(val_dataset, batch_size=512)
+test_loader = DataLoader(test_dataset, batch_size=512)
 
-class GIN(ScalableGNN):
-    def __init__(self, num_nodes, in_channels, hidden_channels, out_channels,
-                 num_layers):
-        super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
-                                  pool_size=2, buffer_size=200000)
+
+# We define our own GAS+GIN module:
+class GIN(ScalableGNN):
+    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
+                 out_channels: int, num_layers: int):
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size=2,
+                         buffer_size=60000)
+        # pool_size determines the number of pinned CPU buffers
+        # buffer_size determines the size of pinned CPU buffers,
+        # i.e. the maximum number of out-of-mini-batch nodes
 
         self.in_channels = in_channels
         self.out_channels = out_channels
 
         self.lins = torch.nn.ModuleList()
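A note on the partitioning step above, for readers new to GAS: `metis` returns a node permutation `perm` together with a `ptr` vector, which (as its use in `SubgraphLoader` suggests) holds the cumulative node boundaries of the 10,000 METIS parts after permutation; `batch_size=256` in `SubgraphLoader` then groups 256 of these parts, plus their direct out-of-batch neighbors, into one mini-batch. A minimal, purely illustrative sketch of how `ptr` can be inspected under that assumption:

# Illustration only (not part of the commit); assumes `ptr` is the cumulative
# boundary tensor produced by `metis(...)` above.
part_sizes = ptr[1:] - ptr[:-1]            # number of nodes in each METIS part
print(int(part_sizes.min()), int(part_sizes.max()))
print(int(part_sizes[:256].sum()))         # rough node count of the first mini-batch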
@@ -64,46 +68,43 @@ class GIN(ScalableGNN):
             mlp = Sequential(
                 Linear(hidden_channels, hidden_channels),
                 BatchNorm1d(hidden_channels, track_running_stats=False),
-                ReLU(inplace=True),
+                ReLU(),
                 Linear(hidden_channels, hidden_channels),
                 ReLU(),
             )
             self.mlps.append(mlp)
 
-    def forward(self, x, adj_t, batch_size=None, n_id=None, offset=None,
-                count=None):
-        reg = 0
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args):
         x = self.lins[0](x).relu_()
 
-        for i, (conv, mlp, hist) in enumerate(
-                zip(self.convs[:-1], self.mlps[:-1], self.histories)):
+        reg = 0
+        it = zip(self.convs[:-1], self.mlps[:-1], self.histories)
+        for i, (conv, mlp, history) in enumerate(it):
             h = conv((x, x[:adj_t.size(0)]), adj_t)
 
-            # Enforce Lipschitz continuity via regularization (part 1):
+            # Regularize Lipschitz continuity via regularization (part 1):
             if i > 0 and self.training:
-                eps = 0.01 * torch.randn_like(h)
-                approx = mlp(h + eps)
+                approx = mlp(h + 0.1 * torch.randn_like(h))
 
             h = mlp(h)
 
-            # Enforce Lipschitz continuity via regularization (part 2):
+            # Regularize Lipschitz continuity via regularization (part 2):
            if i > 0 and self.training:
                 diff = (h - approx).norm(dim=-1)
                 reg += diff.mean() / len(self.histories)
 
-            h += x[:h.size(0)]
-            x = self.push_and_pull(hist, h, batch_size, n_id, offset, count)
+            h += x[:h.size(0)]  # Simple skip-connection
+            x = self.push_and_pull(history, h, *args)
 
         h = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
         h = self.mlps[-1](h)
         h += x[:h.size(0)]
+        x = self.lins[1](h)
 
-        return self.lins[1](h), reg
+        return x, reg
 
     @torch.no_grad()
-    def forward_layer(self, layer, x, adj_t, state):
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state):
         if layer == 0:
             x = self.lins[0](x).relu_()
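The perturbation-based regularizer in the forward pass can be read in isolation: a noisy copy of the layer input is pushed through the same MLP, and the distance between the clean and noisy outputs is penalized, which discourages the MLP from amplifying small input perturbations such as stale history embeddings. A self-contained sketch of that idea (illustrative only; the helper name is made up and not part of the commit):

import torch
from torch.nn import Linear, ReLU, Sequential

def lipschitz_penalty(mlp, h, noise=0.1):
    # Compare the MLP output on clean vs. slightly perturbed inputs and
    # penalize the mean distance between the two outputs.
    approx = mlp(h + noise * torch.randn_like(h))
    return (mlp(h) - approx).norm(dim=-1).mean()

mlp = Sequential(Linear(16, 16), ReLU(), Linear(16, 16))
h = torch.randn(32, 16)
print(float(lipschitz_penalty(mlp, h)))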
@@ -118,7 +119,7 @@ class GIN(ScalableGNN):
 model = GIN(
-    num_nodes=train_dataset.data.num_nodes,
+    num_nodes=data.num_nodes,
     in_channels=in_channels,
     hidden_channels=128,
     out_channels=out_channels,
@@ -127,23 +128,20 @@ model = GIN(
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 criterion = torch.nn.CrossEntropyLoss()
-scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20,
-                              min_lr=1e-5)
+scheduler = ReduceLR(optimizer, 'max', factor=0.5, patience=20, min_lr=1e-5)
 
 
 def train(model, loader, optimizer):
     model.train()
 
     total_loss = total_examples = 0
-    for batch, batch_size, n_id, offset, count in loader:
+    for batch, *args in loader:
         batch = batch.to(model.device)
         optimizer.zero_grad()
-        out, reg = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
-        loss = criterion(out, batch.y[:batch_size]) + reg
+        out, reg = model(batch.x, batch.adj_t, *args)
+        loss = criterion(out, batch.y[:out.size(0)]) + reg
         loss.backward()
         optimizer.step()
 
         total_loss += float(loss) * int(out.size(0))
         total_examples += int(out.size(0))
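The `*args` that the loader yields and the training loop forwards into the model are the mini-batch metadata (batch size, global node ids, offsets and counts) that `ScalableGNN.push_and_pull` uses to write fresh in-batch embeddings into a per-layer history buffer and read back stale embeddings for out-of-mini-batch neighbors. A rough, simplified sketch of that mechanism, under stated assumptions about its semantics (this is not the library implementation, which uses pinned CPU buffers and asynchronous transfers):

import torch

num_nodes, channels = 10, 4
history = torch.zeros(num_nodes, channels)        # one buffer per layer

def push_and_pull(history, h, batch_size, n_id):
    # Push fresh embeddings of in-batch nodes into the history buffer ...
    history[n_id[:batch_size]] = h[:batch_size].detach()
    # ... and pull (possibly stale) embeddings for out-of-batch neighbors.
    out = h.clone()
    out[batch_size:] = history[n_id[batch_size:]]
    return out

h = torch.randn(6, channels)                      # 4 in-batch + 2 out-of-batch nodes
n_id = torch.tensor([0, 1, 2, 3, 7, 9])           # global node ids of this mini-batch
print(push_and_pull(history, h, batch_size=4, n_id=n_id).shape)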
@@ -172,6 +170,7 @@ def mini_test(model, loader, y):
 mini_test(model, eval_loader, data.y)  # Fill history.
 
 for epoch in range(1, 151):
     lr = optimizer.param_groups[0]['lr']
     loss = train(model, train_loader, optimizer)
...
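Since the scheduler is constructed in 'max' mode, it expects to be stepped with a metric that should increase; the actual step call falls in the part of the loop elided above. For reference, the standard ReduceLROnPlateau usage pattern looks like the following (hypothetical: it assumes the script's `mini_test` returns a validation accuracy, which is not shown in this diff):

val_acc = mini_test(model, eval_loader, data.y)   # assumed to return an accuracy
scheduler.step(val_acc)                           # lowers lr after `patience` epochs without improvement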