Unverified Commit 35e5bca9 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the MLP example. (#6593)

parent 3af61c6b
......@@ -99,23 +99,24 @@ def train(
preds = torch.zeros(labels.shape[0], n_classes)
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
with dataloader.enable_cpu_affinity():
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach()
loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
count = len(new_train_idx)
loss_sum += loss.item() * count
total += count
count = len(new_train_idx)
loss_sum += loss.item() * count
total += count
preds = preds.to(train_idx.device)
return (
......@@ -143,11 +144,12 @@ def evaluate(
eval_times = 1 # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.
for _ in range(eval_times):
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
with dataloader.enable_cpu_affinity():
for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred
pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred
preds /= eval_times
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment