Commit ff60382e authored by rusty1s

adjust small benchmark script

parent d0564e2e
@@ -16,7 +16,7 @@ args = parser.parse_args()
 torch.manual_seed(12345)
 device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
-data, in_channels, out_channels = get_data(args.root, name='cora')
+data, in_channels, out_channels = get_data(args.root, name='Cora')
 # Pre-process adjacency matrix for GCN:
 data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)
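For context: `gcn_norm` applies the symmetric normalization from Kipf & Welling, `A_hat = D^{-1/2} (A + I) D^{-1/2}`, once up front instead of inside every forward pass. A minimal dense sketch of what it computes (the script itself operates on a `SparseTensor`):

```python
import torch

# Toy 3-node path graph; gcn_norm does the same on a sparse adjacency.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A = A + torch.eye(3)                  # add_self_loops=True
deg = A.sum(dim=1)                    # node degrees incl. self-loops
d_inv_sqrt = deg.pow(-0.5)
A_hat = d_inv_sqrt.view(-1, 1) * A * d_inv_sqrt.view(1, -1)
print(A_hat)                          # symmetrically normalized adjacency
```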
@@ -63,7 +63,7 @@ def train(model, loader, optimizer):
 @torch.no_grad()
-def test(data, model):
+def test(model, data):
     model.eval()
     out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
@@ -78,7 +78,7 @@ test(data, model)  # Fill history.
 best_val_acc = test_acc = 0
 for epoch in range(1, 201):
     train(model, loader, optimizer)
-    train_acc, val_acc, tmp_test_acc = test(data, model)
+    train_acc, val_acc, tmp_test_acc = test(model, data)
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         test_acc = tmp_test_acc
......
# Benchmark on Small-scale Graphs
```
python main.py
python main.py model=gcn dataset=cora root=/tmp/datasets device=0
```
You can choose between the following models and datasets:
* **Models:** `gcn`, `gat`, `appnp`, `gcn2`
* **Datasets:** `cora`, `citeseer`, `pubmed`
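Options combine freely via Hydra-style overrides, for example:

```
python main.py model=gat dataset=pubmed root=/tmp/datasets device=0
```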
 import hydra
 from tqdm import tqdm
 from omegaconf import OmegaConf

 import torch
 from torch_geometric.nn.conv.gcn_conv import gcn_norm

-from scaling_gnns import get_data, models, SubgraphLoader, compute_acc
+from torch_geometric_autoscale import (get_data, metis, permute,
+                                       SubgraphLoader, models, compute_acc)

 torch.manual_seed(123)
 criterion = torch.nn.CrossEntropyLoss()
-def train(run, data, model, loader, optimizer, grad_norm=None):
+def train(run, model, loader, optimizer, grad_norm=None):
     model.train()

-    train_mask = data.train_mask
-    train_mask = train_mask[:, run] if train_mask.dim() == 2 else train_mask

     total_loss = total_examples = 0
-    for info in loader:
-        info = info.to(model.device)
-        batch_size, n_id, adj_t, e_id = info
-        y = data.y[n_id[:batch_size]]
-        mask = train_mask[n_id[:batch_size]]
+    for batch, batch_size, n_id, _, _ in loader:
+        batch = batch.to(model.device)
+        n_id = n_id.to(model.device)
+        mask = batch.train_mask[:batch_size]
+        mask = mask[:, run] if mask.dim() == 2 else mask
         if mask.sum() == 0:
             continue

         optimizer.zero_grad()
-        out = model(data.x[n_id], adj_t, batch_size, n_id)
-        loss = criterion(out[mask], y[mask])
+        out = model(batch.x, batch.adj_t, batch_size, n_id)
+        loss = criterion(out[mask], batch.y[:batch_size][mask])
         loss.backward()
         if grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
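One subtlety in the rewritten loop: `batch.train_mask` can be two-dimensional (one column per run, on datasets that ship multiple splits), so the current run's column is selected after slicing out the in-batch target nodes. A self-contained toy of that logic (shapes are illustrative):

```python
import torch

num_nodes, batch_size, num_runs = 6, 4, 3
train_mask = torch.rand(num_nodes, num_runs) > 0.5  # stand-in for batch.train_mask
run = 1

mask = train_mask[:batch_size]                    # restrict to in-batch target nodes
mask = mask[:, run] if mask.dim() == 2 else mask  # pick this run's split
print(mask)                                       # 1-D mask used for the loss
```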
@@ -42,7 +40,7 @@ def train(run, data, model, loader, optimizer, grad_norm=None):
 @torch.no_grad()
-def test(run, data, model):
+def test(run, model, data):
     model.eval()

     val_mask = data.val_mask
@@ -75,25 +73,21 @@ def main(conf):
     elif conf.model.loop:
         data.adj_t = data.adj_t.set_diag()

-    loader = SubgraphLoader(
-        data.adj_t,
-        batch_size=params.batch_size,
-        use_metis=True,
-        num_parts=params.num_parts,
-        shuffle=True,
-        num_workers=params.num_workers,
-        path=f'../../../metis/{model_name.lower()}_{dataset_name.lower()}',
-        log=False,
-    )
+    perm, ptr = metis(data.adj_t, num_parts=params.num_parts, log=True)
+    data = permute(data, perm, log=True)
+    loader = SubgraphLoader(data, ptr, batch_size=params.batch_size,
+                            shuffle=True, num_workers=params.num_workers,
+                            persistent_workers=params.num_workers > 0)

-    data = data.to(device)
+    data = data.clone().to(device)  # Let's just store all data on GPU...

     GNN = getattr(models, model_name)
     model = GNN(
         num_nodes=data.num_nodes,
         in_channels=in_channels,
         out_channels=out_channels,
-        device=device,
+        device=device,  # Put histories on GPU.
         **params.architecture,
     ).to(device)
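Pulled out of the diff context, the new pre-partitioning pipeline reads as follows. A sketch only: `num_parts` and `batch_size` values are illustrative, and running it requires the `torch_geometric_autoscale` package; the function signatures are taken from this diff.

```python
from torch_geometric_autoscale import get_data, metis, permute, SubgraphLoader

data, in_channels, out_channels = get_data('/tmp/datasets', name='Cora')
perm, ptr = metis(data.adj_t, num_parts=40, log=True)  # METIS node clustering
data = permute(data, perm, log=True)  # reorder nodes so clusters are contiguous
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)
```

Compared to the old `SubgraphLoader(data.adj_t, use_metis=True, ...)` call, partitioning and permutation are now explicit one-time steps, so the loader itself only has to slice contiguous clusters.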
@@ -108,14 +102,12 @@ def main(conf):
         weight_decay=params.nonreg_weight_decay)
     ], lr=params.lr)

-    with torch.no_grad():  # Fill history.
-        model.eval()
-        model(data.x, data.adj_t)
+    test(0, model, data)  # Fill history.

     best_val_acc = 0
     for epoch in range(params.epochs):
-        train(run, data, model, loader, optimizer, params.grad_norm)
-        val_acc, test_acc = test(run, data, model)
+        train(run, model, loader, optimizer, params.grad_norm)
+        val_acc, test_acc = test(run, model, data)
         if val_acc > best_val_acc:
             best_val_acc = val_acc
             results[run] = test_acc
......
@@ -74,7 +74,9 @@ class ScalableGNN(torch.nn.Module):
                 and n_id is not None and offset is not None
                 and count is not None)

-        if batch_size is not None and not self._async:
+        if (batch_size is not None and not self._async
+                and str(self.emb_device) == 'cpu'
+                and str(self.device)[:4] == 'cuda'):
             warnings.warn('Asynchronous I/O disabled, although history and '
                           'model sit on different devices.')
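The extra clauses narrow the warning to the only configuration where asynchronous I/O would actually pay off: histories held on the CPU while the model computes on CUDA. A quick check of the string tests used above (runs without a GPU, since only device objects are created):

```python
import torch

emb_device = torch.device('cpu')   # where histories live
device = torch.device('cuda:0')    # where the model computes
print(str(emb_device) == 'cpu')    # True
print(str(device)[:4] == 'cuda')   # True -> the warning would fire
```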
......
@@ -61,8 +61,9 @@ class GCN(ScalableGNN):
     def reset_parameters(self):
         super(GCN, self).reset_parameters()
-        for lin in self.lins:
-            lin.reset_parameters()
+        if self.linear:
+            for lin in self.lins:
+                lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
......
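The guard presumably prevents an `AttributeError` when the model was built without the extra linear layers, in which case `self.lins` would never have been created (an assumption about the constructor, inferred from this diff). A toy version of the pattern:

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self, linear: bool):
        super().__init__()
        self.linear = linear
        if linear:  # assumed: `lins` only exists when linear=True
            self.lins = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def reset_parameters(self):
        if self.linear:
            for lin in self.lins:
                lin.reset_parameters()

Toy(linear=False).reset_parameters()  # safe: no AttributeError
```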