Commit 3e7d0d2e authored by rusty1s

comments to train_gcn

parent 6a0c21bd
@@ -23,7 +23,9 @@ from torch_geometric_autoscale import metis, permute, SubgraphLoader
 class GNN(ScalableGNN):
     def __init__(self, num_nodes, in_channels, hidden_channels, out_channels, num_layers):
         # pool_size determines the number of pinned CPU buffers
-        # buffer_size determines the size of pinned CPU buffers
+        # buffer_size determines the size of pinned CPU buffers,
+        # i.e. the maximum number of out-of-mini-batch nodes
         super(GNN, self).__init__(num_nodes, hidden_channels, num_layers,
                                   pool_size=2, buffer_size=5000)
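
Note: for context, the class sketched in this hunk is typically completed along the lines below. This is only an illustration, not part of the commit; the GCNConv layers and the exact forward pass are assumptions, while `self.histories` and `push_and_pull` are the helpers ScalableGNN exposes for writing in-batch embeddings to, and reading out-of-batch embeddings from, the pinned history buffers.

    from torch.nn import ModuleList
    from torch_geometric.nn import GCNConv
    from torch_geometric_autoscale import ScalableGNN


    class GNN(ScalableGNN):
        def __init__(self, num_nodes, in_channels, hidden_channels,
                     out_channels, num_layers):
            # pool_size: number of pinned CPU buffers
            # buffer_size: maximum number of out-of-mini-batch nodes per buffer
            super().__init__(num_nodes, hidden_channels, num_layers,
                             pool_size=2, buffer_size=5000)

            # Layer choice is illustrative; any message-passing layer works.
            self.convs = ModuleList()
            for i in range(num_layers):
                in_dim = in_channels if i == 0 else hidden_channels
                out_dim = out_channels if i == num_layers - 1 else hidden_channels
                self.convs.append(GCNConv(in_dim, out_dim))

        def forward(self, x, adj_t, *args):
            # After every hidden layer, push the freshly computed in-batch
            # embeddings to the history and pull cached out-of-batch ones.
            for conv, history in zip(self.convs[:-1], self.histories):
                x = conv(x, adj_t).relu_()
                x = self.push_and_pull(history, x, *args)
            return self.convs[-1](x, adj_t)
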
@@ -16,7 +16,7 @@ args = parser.parse_args()
 torch.manual_seed(12345)
 device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
-data, in_channels, out_channels = get_data(args.root, name='Cora')
+data, in_channels, out_channels = get_data(args.root, name='cora')
 # Pre-process adjacency matrix for GCN:
 data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)
@@ -36,8 +36,8 @@ model = GCN(
     num_layers=2,
     dropout=0.5,
     drop_input=True,
-    pool_size=2,
-    buffer_size=1000,
+    pool_size=1,  # Number of pinned CPU buffers
+    buffer_size=500,  # Size of pinned CPU buffers (max #out-of-batch nodes)
 ).to(device)
 optimizer = torch.optim.Adam([
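
Note: this diff does not show how `loader` is constructed. Based on the `metis`, `permute`, and `SubgraphLoader` imports in the first hunk, a typical setup looks roughly like the sketch below; the exact signatures, partition count, and batch size are assumptions, not part of this commit.

    from torch_geometric_autoscale import metis, permute, SubgraphLoader

    # Partition the graph and reorder nodes so each partition is contiguous,
    # then iterate over partitions as mini-batches (illustrative values):
    perm, ptr = metis(data.adj_t, num_parts=40, log=True)
    data = permute(data, perm, log=True)
    loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)
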
@@ -50,13 +50,12 @@ criterion = torch.nn.CrossEntropyLoss()
 def train(model, loader, optimizer):
     model.train()
-    for batch, batch_size, n_id, offset, count in loader:
+    for batch, *args in loader:
         batch = batch.to(model.device)
-        train_mask = batch.train_mask[:batch_size]
         optimizer.zero_grad()
-        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
-        loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
+        out = model(batch.x, batch.adj_t, *args)
+        train_mask = batch.train_mask[:out.size(0)]
+        loss = criterion(out[train_mask], batch.y[:out.size(0)][train_mask])
         loss.backward()
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
         optimizer.step()
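
For clarity, `*args` in the new loop simply forwards whatever the loader yields besides the batch itself; judging by the removed lines, that is the tuple (batch_size, n_id, offset, count). The sketch below spells the new loop out with those names and is illustrative only, not part of the commit.

    for batch, batch_size, n_id, offset, count in loader:
        batch = batch.to(model.device)
        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
        # Restrict the loss to in-batch training nodes covered by `out`:
        train_mask = batch.train_mask[:out.size(0)]
        loss = criterion(out[train_mask], batch.y[:out.size(0)][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
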
@@ -66,6 +65,7 @@ def train(model, loader, optimizer):
 def test(model, data):
     model.eval()
+    # Full-batch inference since the graph is small
     out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
     train_acc = compute_acc(out, data.y, data.train_mask)
     val_acc = compute_acc(out, data.y, data.val_mask)
@@ -74,7 +74,8 @@ def test(model, data):
     return train_acc, val_acc, test_acc

-test(model, data)  # Fill history.
+test(model, data)  # Fill the history.
 best_val_acc = test_acc = 0
 for epoch in range(1, 201):
     train(model, loader, optimizer)