Commit dcce414c authored by rusty1s

edge dropout

parent fa10cf1b
@@ -125,6 +125,7 @@ params:
     shared_weights: false
     alpha: 0.1
     theta: 0.5
+    edge_dropout: 0.8
     num_parts: 150
     batch_size: 1
     max_steps: 151
......
@@ -98,6 +98,7 @@ params:
     drop_input: false
     batch_norm: false
     residual: false
+    edge_dropout: 0.8
     num_parts: 150
     batch_size: 1
     max_steps: 151
......
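Both configuration hunks above add `edge_dropout: 0.8` under the dataset's `params`, meaning each edge is kept with probability 0.2 at every training step. A minimal, self-contained sketch of that Bernoulli keep-mask (illustrative numbers, not taken from the repo):

import torch

p = 0.8                            # edge_dropout from the config
num_edges = 1_000_000              # hypothetical edge count
keep = torch.rand(num_edges) >= p  # keep each edge with probability 1 - p
print(int(keep.sum()))             # ~200,000 edges survive in expectation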
@@ -7,17 +7,18 @@ from torch_geometric.nn.conv.gcn_conv import gcn_norm
 from torch_geometric_autoscale import (get_data, metis, permute,
                                        SubgraphLoader, EvalSubgraphLoader,
-                                       models, compute_micro_f1)
+                                       models, compute_micro_f1, dropout)
 from torch_geometric_autoscale.data import get_ppi
 
 torch.manual_seed(123)
 
 
-def mini_train(model, loader, criterion, optimizer, max_steps, grad_norm=None):
+def mini_train(model, loader, criterion, optimizer, max_steps, grad_norm=None,
+               edge_dropout=0.0):
     model.train()
 
     total_loss = total_examples = 0
-    for i, (batch, batch_size, n_id, offset, count) in enumerate(loader):
+    for i, (batch, batch_size, *args) in enumerate(loader):
         x = batch.x.to(model.device)
         adj_t = batch.adj_t.to(model.device)
         y = batch.y[:batch_size].to(model.device)
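Besides the new `edge_dropout` argument, the loop now unpacks the loader's tuple as `(batch, batch_size, *args)` instead of naming `n_id, offset, count` explicitly, so the tail is forwarded to the model verbatim. A small runnable sketch of that starred-unpacking equivalence (placeholder values, not real loader output):

def forward(batch_size, *args):
    return (batch_size, *args)

# The loader yields (batch, batch_size, n_id, offset, count); *args
# captures everything after batch_size without naming it.
batch, batch_size, *args = ('batch', 512, 'n_id', 'offset', 'count')
assert args == ['n_id', 'offset', 'count']

# Forwarding with *args reproduces the fully explicit call.
assert forward(batch_size, *args) == (512, 'n_id', 'offset', 'count')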
@@ -26,8 +27,11 @@ def mini_train(model, loader, criterion, optimizer, max_steps, grad_norm=None):
         if train_mask.sum() == 0:
             continue
 
+        # We make use of edge dropout on ogbn-products to avoid overfitting.
+        adj_t = dropout(adj_t, p=edge_dropout)
+
         optimizer.zero_grad()
-        out = model(x, adj_t, batch_size, n_id, offset, count)
+        out = model(x, adj_t, batch_size, *args)
         loss = criterion(out[train_mask], y[train_mask])
         loss.backward()
         if grad_norm is not None:
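`dropout` is imported from `torch_geometric_autoscale`, but its body is not part of this diff. A minimal sketch of what such an edge-dropout helper could look like for a `torch_sparse.SparseTensor`, assuming a plain Bernoulli mask over the COO edges (the package's actual helper may differ, e.g. by renormalizing with the `gcn_norm` imported at the top of the file):

import torch
from torch_sparse import SparseTensor

def dropout(adj_t: SparseTensor, p: float = 0.0) -> SparseTensor:
    # No-op when disabled, matching the edge_dropout=0.0 default above.
    if p <= 0.0:
        return adj_t
    row, col, value = adj_t.coo()
    # Keep each edge independently with probability 1 - p; a fresh mask
    # is drawn every call, i.e. every training step sees a new subgraph.
    keep = torch.rand(row.numel(), device=row.device) >= p
    return SparseTensor(row=row[keep], col=col[keep],
                        value=None if value is None else value[keep],
                        sparse_sizes=adj_t.sparse_sizes())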
@@ -37,7 +41,7 @@ def mini_train(model, loader, criterion, optimizer, max_steps, grad_norm=None):
         total_loss += float(loss) * int(train_mask.sum())
         total_examples += int(train_mask.sum())
 
-        # We abort after a fixed number of steps to refresh histories...
+        # We may abort after a fixed number of steps to refresh histories...
         if (i + 1) >= max_steps and (i + 1) < len(loader):
             break
@@ -61,6 +65,10 @@ def main(conf):
     conf.model.params = conf.model.params[conf.dataset.name]
     params = conf.model.params
     print(OmegaConf.to_yaml(conf))
+    try:
+        edge_dropout = params.edge_dropout
+    except:  # noqa
+        edge_dropout = 0.0
     grad_norm = None if isinstance(params.grad_norm, str) else params.grad_norm
     device = f'cuda:{conf.device}' if torch.cuda.is_available() else 'cpu'
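Only the ogbn-products configs define `edge_dropout`, so the lookup above is wrapped in a bare `try`/`except` that falls back to `0.0`. A small sketch of why the guard is needed, assuming a struct-mode OmegaConf config (Hydra-composed configs are struct by default, so missing keys raise on attribute access):

from omegaconf import OmegaConf

params = OmegaConf.create({'alpha': 0.1, 'theta': 0.5})
OmegaConf.set_struct(params, True)  # missing keys now raise instead of returning None
try:
    edge_dropout = params.edge_dropout
except Exception:  # mirrors the bare except in the diff
    edge_dropout = 0.0
print(edge_dropout)  # 0.0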
@@ -142,7 +150,7 @@ def main(conf):
     best_val_acc = test_acc = 0
     for epoch in range(1, params.epochs + 1):
         loss = mini_train(model, train_loader, criterion, optimizer,
-                          params.max_steps, grad_norm)
+                          params.max_steps, grad_norm, edge_dropout)
         out = mini_test(model, eval_loader)
         train_acc = compute_micro_f1(out, data.y, data.train_mask)
......