Commit ff60382e authored by rusty1s

adjust small benchmark script

parent d0564e2e
@@ -16,7 +16,7 @@ args = parser.parse_args()

 torch.manual_seed(12345)
 device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

-data, in_channels, out_channels = get_data(args.root, name='cora')
+data, in_channels, out_channels = get_data(args.root, name='Cora')
 # Pre-process adjacency matrix for GCN:
 data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)
@@ -63,7 +63,7 @@ def train(model, loader, optimizer):

 @torch.no_grad()
-def test(data, model):
+def test(model, data):
     model.eval()

     out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()

@@ -78,7 +78,7 @@ test(data, model)  # Fill history.
 best_val_acc = test_acc = 0
 for epoch in range(1, 201):
     train(model, loader, optimizer)
-    train_acc, val_acc, tmp_test_acc = test(data, model)
+    train_acc, val_acc, tmp_test_acc = test(model, data)
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         test_acc = tmp_test_acc
...
 # Benchmark on Small-scale Graphs

 ```
-python main.py
+python main.py model=gcn dataset=cora root=/tmp/datasets device=0
 ```
+
+You can choose between the following models and datasets:
+
+* **Models:** `gcn`, `gat`, `appnp`, `gcn2`
+* **Datasets:** `cora`, `citeseer`, `pubmed`
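For reference, the Hydra-style `key=value` overrides shown above compose freely, so any model can be paired with any dataset. A couple of illustrative invocations (model/dataset names taken from the lists above; the `root` path and `device` id are placeholders):

```
python main.py model=gat dataset=citeseer root=/tmp/datasets device=0
python main.py model=gcn2 dataset=pubmed root=/tmp/datasets device=0
```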
 import hydra
 from tqdm import tqdm
 from omegaconf import OmegaConf
 import torch
 from torch_geometric.nn.conv.gcn_conv import gcn_norm
-from scaling_gnns import get_data, models, SubgraphLoader, compute_acc
+from torch_geometric_autoscale import (get_data, metis, permute,
+                                       SubgraphLoader, models, compute_acc)

 torch.manual_seed(123)
 criterion = torch.nn.CrossEntropyLoss()
-def train(run, data, model, loader, optimizer, grad_norm=None):
+def train(run, model, loader, optimizer, grad_norm=None):
     model.train()

-    train_mask = data.train_mask
-    train_mask = train_mask[:, run] if train_mask.dim() == 2 else train_mask
-
     total_loss = total_examples = 0
-    for info in loader:
-        info = info.to(model.device)
-        batch_size, n_id, adj_t, e_id = info
-        y = data.y[n_id[:batch_size]]
-        mask = train_mask[n_id[:batch_size]]
+    for batch, batch_size, n_id, _, _ in loader:
+        batch = batch.to(model.device)
+        n_id = n_id.to(model.device)
+
+        mask = batch.train_mask[:batch_size]
+        mask = mask[:, run] if mask.dim() == 2 else mask
         if mask.sum() == 0:
             continue

         optimizer.zero_grad()
-        out = model(data.x[n_id], adj_t, batch_size, n_id)
-        loss = criterion(out[mask], y[mask])
+        out = model(batch.x, batch.adj_t, batch_size, n_id)
+        loss = criterion(out[mask], batch.y[:batch_size][mask])
         loss.backward()
         if grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
@@ -42,7 +40,7 @@ def train(run, data, model, loader, optimizer, grad_norm=None):

 @torch.no_grad()
-def test(run, data, model):
+def test(run, model, data):
     model.eval()

     val_mask = data.val_mask
@@ -75,25 +73,21 @@ def main(conf):
     elif conf.model.loop:
         data.adj_t = data.adj_t.set_diag()

-    loader = SubgraphLoader(
-        data.adj_t,
-        batch_size=params.batch_size,
-        use_metis=True,
-        num_parts=params.num_parts,
-        shuffle=True,
-        num_workers=params.num_workers,
-        path=f'../../../metis/{model_name.lower()}_{dataset_name.lower()}',
-        log=False,
-    )
+    perm, ptr = metis(data.adj_t, num_parts=params.num_parts, log=True)
+    data = permute(data, perm, log=True)
+
+    loader = SubgraphLoader(data, ptr, batch_size=params.batch_size,
+                            shuffle=True, num_workers=params.num_workers,
+                            persistent_workers=params.num_workers > 0)

-    data = data.to(device)
+    data = data.clone().to(device)  # Let's just store all data on GPU...

     GNN = getattr(models, model_name)
     model = GNN(
         num_nodes=data.num_nodes,
         in_channels=in_channels,
         out_channels=out_channels,
-        device=device,
+        device=device,  # Put histories on GPU.
         **params.architecture,
     ).to(device)

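For orientation, the refactor above splits what `SubgraphLoader` previously did internally (`use_metis=True`) into two explicit steps, `metis` and `permute`, before the loader is built. A minimal sketch of the resulting data path, using only the signatures that appear in this diff (the `num_parts`, `batch_size`, and dataset-root values are illustrative, not the benchmark's configuration):

```python
from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader)

# Load a small benchmark graph, as in the script above.
data, in_channels, out_channels = get_data('/tmp/datasets', name='Cora')

# 1. Partition the adjacency matrix with METIS: `perm` maps nodes into
#    partition order and `ptr` marks the partition boundaries.
perm, ptr = metis(data.adj_t, num_parts=40, log=True)

# 2. Reorder node-level tensors so each partition is contiguous in memory.
data = permute(data, perm, log=True)

# 3. Serve `batch_size` partitions per mini-batch.
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True,
                        num_workers=0, persistent_workers=False)

# Iteration follows the tuple format used in train() above; the last two
# entries (the offset/count bookkeeping that ScalableGNN uses for
# asynchronous history access) are unused here.
for batch, batch_size, n_id, _, _ in loader:
    print(batch_size, n_id.numel())
```

One consequence visible in the diff: the loader no longer takes `use_metis`, `num_parts`, `path`, or `log` arguments; partitioning concerns now live entirely outside loader construction.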
@@ -108,14 +102,12 @@ def main(conf):
             weight_decay=params.nonreg_weight_decay)
     ], lr=params.lr)

-    with torch.no_grad():  # Fill history.
-        model.eval()
-        model(data.x, data.adj_t)
+    test(0, model, data)  # Fill history.

     best_val_acc = 0
     for epoch in range(params.epochs):
-        train(run, data, model, loader, optimizer, params.grad_norm)
-        val_acc, test_acc = test(run, data, model)
+        train(run, model, loader, optimizer, params.grad_norm)
+        val_acc, test_acc = test(run, model, data)
         if val_acc > best_val_acc:
             best_val_acc = val_acc
             results[run] = test_acc
...
@@ -74,7 +74,9 @@ class ScalableGNN(torch.nn.Module):
                        and n_id is not None and offset is not None
                        and count is not None)

-        if batch_size is not None and not self._async:
+        if (batch_size is not None and not self._async
+                and str(self.emb_device) == 'cpu'
+                and str(self.device)[:4] == 'cuda'):
             warnings.warn('Asynchronous I/O disabled, although history and '
                           'model sit on different devices.')

...
@@ -61,8 +61,9 @@ class GCN(ScalableGNN):

     def reset_parameters(self):
         super(GCN, self).reset_parameters()
-        for lin in self.lins:
-            lin.reset_parameters()
+        if self.linear:
+            for lin in self.lins:
+                lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
...