Unverified Commit 35e5bca9 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the MLP example. (#6593)

parent 3af61c6b
...@@ -99,23 +99,24 @@ def train( ...@@ -99,23 +99,24 @@ def train(
preds = torch.zeros(labels.shape[0], n_classes) preds = torch.zeros(labels.shape[0], n_classes)
for _input_nodes, output_nodes, subgraphs in dataloader: with dataloader.enable_cpu_affinity():
subgraphs = [b.to(device) for b in subgraphs] for _input_nodes, output_nodes, subgraphs in dataloader:
new_train_idx = list(range(len(output_nodes))) subgraphs = [b.to(device) for b in subgraphs]
new_train_idx = list(range(len(output_nodes)))
pred = model(subgraphs[0].srcdata["feat"]) pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred.cpu().detach() preds[output_nodes] = pred.cpu().detach()
loss = criterion( loss = criterion(
pred[new_train_idx], labels[output_nodes][new_train_idx] pred[new_train_idx], labels[output_nodes][new_train_idx]
) )
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
count = len(new_train_idx) count = len(new_train_idx)
loss_sum += loss.item() * count loss_sum += loss.item() * count
total += count total += count
preds = preds.to(train_idx.device) preds = preds.to(train_idx.device)
return ( return (
...@@ -143,11 +144,12 @@ def evaluate( ...@@ -143,11 +144,12 @@ def evaluate(
eval_times = 1 # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times. eval_times = 1 # Due to the limitation of memory capacity, we calculate the average of logits 'eval_times' times.
for _ in range(eval_times): for _ in range(eval_times):
for _input_nodes, output_nodes, subgraphs in dataloader: with dataloader.enable_cpu_affinity():
subgraphs = [b.to(device) for b in subgraphs] for _input_nodes, output_nodes, subgraphs in dataloader:
subgraphs = [b.to(device) for b in subgraphs]
pred = model(subgraphs[0].srcdata["feat"]) pred = model(subgraphs[0].srcdata["feat"])
preds[output_nodes] = pred preds[output_nodes] = pred
preds /= eval_times preds /= eval_times
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment