Commit 544ce0ac authored by rusty1s

add large benchmark

parent 36fd4ffe
# Benchmark on Large-Scale Graphs
```
python main.py model=pna dataset=flickr root=/tmp/datasets device=0 log_every=1
```
You can choose between the following models and datasets:
* **Models:** `gcn`, `gcn2`, `pna`, `pna_jk`
* **Datasets:** `reddit`, `ppi`, `flickr`, `yelp`, `arxiv`, `products`
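
Model and dataset are Hydra config groups, so any combination can be selected on the command line, e.g.:
```
python main.py model=gcn2 dataset=reddit root=/tmp/datasets device=0
```
Note that the chosen model config must provide a `params` entry for the chosen dataset (see the config files below); for instance, `pna_jk` only ships parameters for `reddit`.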
# conf/config.yaml
defaults:
  - model: pna
  - dataset: flickr

device: 0
root: '/tmp/datasets'
log_every: 1
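
Given the `defaults` list above, Hydra composes the final configuration from `conf/model/pna.yaml` and `conf/dataset/flickr.yaml`. A sketch of the resolved tree (assuming standard Hydra group composition; `params` abbreviated):
```yaml
model:
  name: PNA
  norm: false
  loop: false
  params: ...  # per-dataset parameters, see the model configs below
dataset:
  name: flickr
device: 0
root: '/tmp/datasets'
log_every: 1
```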
# conf/dataset/arxiv.yaml
# @package _group_
name: arxiv

# conf/dataset/flickr.yaml
# @package _group_
name: flickr

# conf/dataset/ppi.yaml
# @package _group_
name: ppi

# conf/dataset/products.yaml
# @package _group_
name: products

# conf/dataset/reddit.yaml
# @package _group_
name: reddit

# conf/dataset/yelp.yaml
# @package _group_
name: yelp
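
Each dataset config only records its name; the `# @package _group_` directive places it under the `dataset` key. Supporting a further dataset would follow the same pattern (hypothetical example; `get_data` in `main.py` would also need to know how to load it):
```yaml
# conf/dataset/mydataset.yaml (hypothetical)
# @package _group_
name: mydataset
```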
# conf/model/gcn.yaml
# @package _group_
name: GCN
norm: true
loop: true
params:
  reddit:
    architecture:
      num_layers: 2
      hidden_channels: 256
      dropout: 0.5
      drop_input: false
      batch_norm: false
      residual: false
    num_parts: 200
    batch_size: 100
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: none
    epochs: 400
  ppi:
    architecture:
      num_layers: 2
      hidden_channels: 1024
      dropout: 0.0
      drop_input: false
      batch_norm: true
      residual: true
      linear: true
    num_parts: 20
    batch_size: 2
    max_steps: 10
    pool_size: 2
    num_workers: 0
    lr: 0.005
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: null
    epochs: 1000
  flickr:
    architecture:
      num_layers: 2
      hidden_channels: 256
      dropout: 0.3
      drop_input: true
      batch_norm: true
      residual: false
    num_parts: 24
    batch_size: 12
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 400
  yelp:
    architecture:
      num_layers: 2
      hidden_channels: 512
      dropout: 0.0
      drop_input: false
      batch_norm: false
      residual: true
      linear: false
    num_parts: 40
    batch_size: 5
    max_steps: 4
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 500
  arxiv:
    architecture:
      num_layers: 3
      hidden_channels: 256
      dropout: 0.5
      drop_input: false
      batch_norm: true
      residual: false
    num_parts: 80
    batch_size: 40
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: none
    epochs: 300
    runs: 1
  products:
    architecture:
      num_layers: 3
      hidden_channels: 256
      dropout: 0.3
      drop_input: false
      batch_norm: false
      residual: false
    num_parts: 7
    batch_size: 1
    max_steps: 4
    pool_size: 2
    num_workers: 6
    lr: 0.005
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 350
    runs: 1
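
Any of these values can also be overridden from the command line via Hydra's dotted syntax, e.g.:
```
python main.py model=gcn dataset=reddit model.params.reddit.architecture.hidden_channels=128
```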
# conf/model/gcn2.yaml
# @package _group_
name: GCN2
norm: true
loop: true
params:
  reddit:
    architecture:
      num_layers: 4
      hidden_channels: 256
      dropout: 0.5
      drop_input: true
      batch_norm: true
      residual: false
      shared_weights: false
      alpha: 0.1
      theta: 0.5
    num_parts: 200
    batch_size: 100
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: null
    epochs: 400
  ppi:
    architecture:
      num_layers: 9
      hidden_channels: 2048
      dropout: 0.2
      drop_input: true
      batch_norm: false
      residual: true
      shared_weights: false
      alpha: 0.5
      theta: 1.0
    num_parts: 20
    batch_size: 2
    max_steps: 10
    pool_size: 2
    num_workers: 0
    lr: 0.001
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: 1.0
    epochs: 2000
  arxiv:
    architecture:
      num_layers: 4
      hidden_channels: 256
      dropout: 0.3
      drop_input: false
      batch_norm: true
      residual: false
      shared_weights: true
      alpha: 0.2
      theta: 0.5
    num_parts: 40
    batch_size: 20
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: null
    epochs: 500
  flickr:
    architecture:
      num_layers: 8
      hidden_channels: 256
      dropout: 0.5
      drop_input: true
      batch_norm: true
      residual: false
      shared_weights: false
      alpha: 0.1
      theta: 0.5
    num_parts: 24
    batch_size: 12
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 400
  yelp:
    architecture:
      num_layers: 2
      hidden_channels: 512
      dropout: 0.0
      drop_input: false
      batch_norm: false
      residual: false
      shared_weights: false
      alpha: 0.2
      theta: 0.5
    num_parts: 40
    batch_size: 5
    max_steps: 4
    pool_size: 2
    num_workers: 0
    lr: 0.01
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 500
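
In the GCN2 configs, `alpha` sets the strength of the initial residual connection and `theta` controls the per-layer identity-mapping weight β_l = log(θ/l + 1) from the GCNII paper (Chen et al., 2020). A minimal sketch of that schedule, not the library code:
```python
import math

def gcn2_beta(theta: float, layer: int) -> float:
    # Identity-mapping strength decays with depth (GCNII, Chen et al., 2020):
    return math.log(theta / layer + 1)
```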
# conf/model/pna.yaml
# @package _group_
name: PNA
norm: false
loop: false
params:
  ppi:
    architecture:
      num_layers: 5
      hidden_channels: 2048
      aggregators: ['mean']
      scalers: ['identity', 'amplification']
      dropout: 0.2
      drop_input: true
      batch_norm: true
      residual: true
    num_parts: 20
    batch_size: 2
    max_steps: 10
    pool_size: 2
    num_workers: 0
    lr: 0.001
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: 1.0
    epochs: 300
  arxiv:
    architecture:
      num_layers: 3
      hidden_channels: 256
      aggregators: ['mean']
      scalers: ['identity', 'amplification']
      dropout: 0.5
      drop_input: false
      batch_norm: true
      residual: false
    num_parts: 40
    batch_size: 20
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.005
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: null
    epochs: 500
  flickr:
    architecture:
      num_layers: 4
      hidden_channels: 64
      aggregators: ['mean', 'max']
      scalers: ['identity', 'amplification']
      dropout: 0.5
      drop_input: true
      batch_norm: true
      residual: false
    num_parts: 24
    batch_size: 12
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.005
    reg_weight_decay: 0
    nonreg_weight_decay: 0
    grad_norm: null
    epochs: 500
  yelp:
    architecture:
      num_layers: 3
      hidden_channels: 512
      aggregators: ['mean']
      scalers: ['identity', 'amplification']
      dropout: 0.0
      drop_input: false
      batch_norm: false
      residual: true
    num_parts: 40
    batch_size: 5
    max_steps: 4
    pool_size: 2
    num_workers: 0
    lr: 0.005
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: 1.0
    epochs: 400
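
PNA pairs every aggregator with every scaler; the `amplification` scaler upweights messages of high-degree nodes by log(d + 1) relative to the average log-degree δ of the training graph (Corso et al., 2020). A minimal sketch of that scaling, not the library implementation:
```python
import torch

def amplify(aggregated: torch.Tensor, deg: torch.Tensor,
            avg_log_deg: float) -> torch.Tensor:
    # Scale each node's aggregated message by log(deg + 1) / delta:
    scale = torch.log(deg.to(aggregated.dtype) + 1) / avg_log_deg
    return aggregated * scale.view(-1, 1)
```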
# conf/model/pna_jk.yaml
# @package _group_
name: PNA_JK
norm: false
loop: false
params:
  reddit:
    architecture:
      num_layers: 3
      hidden_channels: 128
      aggregators: ['mean', 'max']
      scalers: ['identity', 'amplification']
      dropout: 0.5
      drop_input: true
      batch_norm: true
      residual: true
    num_parts: 200
    batch_size: 100
    max_steps: 2
    pool_size: 2
    num_workers: 0
    lr: 0.005
    reg_weight_decay: 0.0
    nonreg_weight_decay: 0.0
    grad_norm: 1.0
    epochs: 400
# main.py
import time

import hydra
from omegaconf import OmegaConf
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader, EvalSubgraphLoader,
                                       models, compute_acc)
from torch_geometric_autoscale.data import get_ppi

torch.manual_seed(123)
def mini_train(model, loader, criterion, optimizer, max_steps, grad_norm=None):
    model.train()

    total_loss = total_examples = 0
    for i, (batch, batch_size, n_id, offset, count) in enumerate(loader):
        x = batch.x.to(model.device)
        adj_t = batch.adj_t.to(model.device)
        y = batch.y[:batch_size].to(model.device)
        train_mask = batch.train_mask[:batch_size].to(model.device)

        # Skip mini-batches that do not contain any training nodes:
        if train_mask.sum() == 0:
            continue

        optimizer.zero_grad()
        out = model(x, adj_t, batch_size, n_id, offset, count)
        loss = criterion(out[train_mask], y[train_mask])
        loss.backward()
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        optimizer.step()

        total_loss += float(loss) * int(train_mask.sum())
        total_examples += int(train_mask.sum())

        # Perform at most `max_steps` mini-batch updates per epoch:
        if (i + 1) >= max_steps and (i + 1) < len(loader):
            break

    return total_loss / total_examples
@torch.no_grad()
def full_test(model, data):
    model.eval()
    return model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()


@torch.no_grad()
def mini_test(model, loader):
    model.eval()
    return model(loader=loader)
@hydra.main(config_path='conf', config_name='config')
def main(conf):
    # Narrow the model parameters down to the chosen dataset:
    conf.model.params = conf.model.params[conf.dataset.name]
    params = conf.model.params
    print(OmegaConf.to_yaml(conf))

    # `grad_norm: none` (a string) disables gradient clipping:
    grad_norm = None if isinstance(params.grad_norm, str) else params.grad_norm
    device = f'cuda:{conf.device}' if torch.cuda.is_available() else 'cpu'

    t = time.perf_counter()
    print('Loading data...', end=' ', flush=True)
    data, in_channels, out_channels = get_data(conf.root, conf.dataset.name)
    print(f'Done! [{time.perf_counter() - t:.2f}s]')

    # Partition the graph via METIS and permute nodes accordingly:
    perm, ptr = metis(data.adj_t, num_parts=params.num_parts, log=True)
    data = permute(data, perm, log=True)

    if conf.model.loop:
        t = time.perf_counter()
        print('Adding self-loops...', end=' ', flush=True)
        data.adj_t = data.adj_t.set_diag()
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    if conf.model.norm:
        t = time.perf_counter()
        print('Normalizing data...', end=' ', flush=True)
        data.adj_t = gcn_norm(data.adj_t, add_self_loops=False)
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    # Single-label vs. multi-label classification:
    if data.y.dim() == 1:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()
    train_loader = SubgraphLoader(data, ptr, batch_size=params.batch_size,
                                  shuffle=True,
                                  num_workers=params.num_workers,
                                  persistent_workers=params.num_workers > 0)
    eval_loader = EvalSubgraphLoader(data, ptr, batch_size=params.batch_size)

    # PPI ships separate validation and test graphs:
    if conf.dataset.name == 'ppi':
        val_data, _, _ = get_ppi(conf.root, split='val')
        test_data, _, _ = get_ppi(conf.root, split='test')
        if conf.model.loop:
            val_data.adj_t = val_data.adj_t.set_diag()
            test_data.adj_t = test_data.adj_t.set_diag()
        if conf.model.norm:
            val_data.adj_t = gcn_norm(val_data.adj_t, add_self_loops=False)
            test_data.adj_t = gcn_norm(test_data.adj_t, add_self_loops=False)

    t = time.perf_counter()
    print('Calculating buffer size...', end=' ', flush=True)
    # The buffer needs to hold the largest mini-batch (in number of nodes):
    buffer_size = max([n_id.numel() for _, _, n_id, _, _ in eval_loader])
    print(f'Done! [{time.perf_counter() - t:.2f}s] -> {buffer_size}')
    kwargs = {}
    if conf.model.name[:3] == 'PNA':
        # PNA-style models need node degree information:
        kwargs['deg'] = data.adj_t.storage.rowcount()

    GNN = getattr(models, conf.model.name)
    model = GNN(
        num_nodes=data.num_nodes,
        in_channels=in_channels,
        out_channels=out_channels,
        pool_size=params.pool_size,
        buffer_size=buffer_size,
        **params.architecture,
        **kwargs,
    ).to(device)

    # Apply weight decay only to the regularized modules:
    optimizer = torch.optim.Adam([
        dict(params=model.reg_modules.parameters(),
             weight_decay=params.reg_weight_decay),
        dict(params=model.nonreg_modules.parameters(),
             weight_decay=params.nonreg_weight_decay),
    ], lr=params.lr)

    t = time.perf_counter()
    print('Filling history...', end=' ', flush=True)
    # A full inference pass initializes the historical embeddings:
    mini_test(model, eval_loader)
    print(f'Done! [{time.perf_counter() - t:.2f}s]')
    best_val_acc = test_acc = 0
    for epoch in range(1, params.epochs + 1):
        loss = mini_train(model, train_loader, criterion, optimizer,
                          params.max_steps, grad_norm)
        out = mini_test(model, eval_loader)
        train_acc = compute_acc(out, data.y, data.train_mask)

        if conf.dataset.name != 'ppi':
            val_acc = compute_acc(out, data.y, data.val_mask)
            tmp_test_acc = compute_acc(out, data.y, data.test_mask)
        else:
            # PPI is evaluated on its dedicated validation and test graphs:
            val_acc = compute_acc(full_test(model, val_data), val_data.y)
            tmp_test_acc = compute_acc(full_test(model, test_data),
                                       test_data.y)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc

        if epoch % conf.log_every == 0:
            print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
                  f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
                  f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')

    print('=========================')
    print(f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')


if __name__ == "__main__":
    main()
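
`compute_acc` is imported from `torch_geometric_autoscale`; a plausible drop-in with the same call signature is sketched below for reference (an assumption for illustration only, mirroring the single-label vs. multi-label split used for the loss above):
```python
import torch

def compute_acc_sketch(out, y, mask=None):
    # Hypothetical stand-in for torch_geometric_autoscale.compute_acc.
    if mask is not None:
        out, y = out[mask], y[mask]
    if y.dim() == 1:  # single-label: plain accuracy over argmax predictions
        return int((out.argmax(dim=-1) == y).sum()) / y.numel()
    # multi-label (e.g. ppi/yelp): micro-averaged F1 over thresholded logits
    pred = (out > 0).to(y.dtype)
    tp = int((pred * y).sum())
    precision = tp / max(int(pred.sum()), 1)
    recall = tp / max(int(y.sum()), 1)
    return 2 * precision * recall / max(precision + recall, 1e-12)
```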
# Benchmark on Small-Scale Graphs
```
python main.py model=gcn dataset=cora root=/tmp/datasets device=0
```
…