Commit 3e7d0d2e authored by rusty1s

comments to train_gcn

parent 6a0c21bd
......@@ -23,7 +23,9 @@ from torch_geometric_autoscale import metis, permute, SubgraphLoader
class GNN(ScalableGNN):
def __init__(self, num_nodes, in_channels, hidden_channels, out_channels, num_layers):
# pool_size determines the number of pinned CPU buffers
# buffer_size determines the size of pinned CPU buffers
# buffer_size determines the size of pinned CPU buffers,
# i.e. the maximum number of out-of-mini-batch nodes
super(GNN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size=2, buffer_size=5000)
......
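As a side note on the new comment above: `buffer_size` has to be at least as large as the number of out-of-mini-batch nodes encountered in any single batch. A minimal, hypothetical sketch of how that number could be measured from a `SubgraphLoader` (assuming it yields the same `(batch, batch_size, n_id, offset, count)` tuple used in the training loop below) — this helper is an illustration, not part of this commit:

```python
# Hypothetical helper, not part of this commit: measure the largest number of
# out-of-mini-batch nodes produced by the loader, which is a lower bound for
# `buffer_size`. Assumes the loader yields (batch, batch_size, n_id, offset, count).
def estimate_buffer_size(loader):
    max_out_of_batch = 0
    for batch, batch_size, n_id, offset, count in loader:
        # `n_id` holds the ids of all nodes in the sampled subgraph; the first
        # `batch_size` entries are the in-batch nodes, the remaining ones are the
        # out-of-mini-batch nodes served from the pinned CPU history buffers.
        max_out_of_batch = max(max_out_of_batch, n_id.numel() - batch_size)
    return max_out_of_batch
```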
......@@ -16,7 +16,7 @@ args = parser.parse_args()
torch.manual_seed(12345)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
data, in_channels, out_channels = get_data(args.root, name='Cora')
data, in_channels, out_channels = get_data(args.root, name='cora')
# Pre-process adjacency matrix for GCN:
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)
......@@ -36,8 +36,8 @@ model = GCN(
num_layers=2,
dropout=0.5,
drop_input=True,
pool_size=2,
buffer_size=1000,
pool_size=1, # Number of pinned CPU buffers
buffer_size=500, # Size of pinned CPU buffers (max #out-of-batch nodes)
).to(device)
optimizer = torch.optim.Adam([
......@@ -50,13 +50,12 @@ criterion = torch.nn.CrossEntropyLoss()
def train(model, loader, optimizer):
model.train()
for batch, batch_size, n_id, offset, count in loader:
for batch, *args in loader:
batch = batch.to(model.device)
train_mask = batch.train_mask[:batch_size]
optimizer.zero_grad()
out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
out = model(batch.x, batch.adj_t, *args)
train_mask = batch.train_mask[:out.size(0)]
loss = criterion(out[train_mask], batch.y[:out.size(0)][train_mask])
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
......@@ -66,6 +65,7 @@ def train(model, loader, optimizer):
def test(model, data):
model.eval()
# Full-batch inference since the graph is small
out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
train_acc = compute_acc(out, data.y, data.train_mask)
val_acc = compute_acc(out, data.y, data.val_mask)
......@@ -74,7 +74,8 @@ def test(model, data):
return train_acc, val_acc, test_acc
test(model, data) # Fill history.
test(model, data) # Fill the history.
best_val_acc = test_acc = 0
for epoch in range(1, 201):
train(model, loader, optimizer)
......
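For readers not familiar with the script, here is a hedged sketch of how the `loader` consumed by `train()` is typically built from the symbols imported above (`metis`, `permute`, `SubgraphLoader`); the exact keyword arguments and the partition/batch sizes are assumptions for illustration, not part of this commit:

```python
# Hedged sketch, not part of the diff: partition the graph with METIS, reorder
# nodes so that each partition is contiguous, and mini-batch over partitions.
from torch_geometric_autoscale import metis, permute, SubgraphLoader

# `data` is the graph returned by get_data() in the script above.
perm, ptr = metis(data.adj_t, num_parts=40, log=True)  # partition assignment (40 parts assumed)
data = permute(data, perm, log=True)                    # reorder nodes so partitions are contiguous
loader = SubgraphLoader(data, ptr, batch_size=10,       # 10 partitions per mini-batch (assumed)
                        shuffle=True)
```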