"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "909742dbd6873052995dc6cd5f4150ff238015d2"
Unverified commit 7f50a6da authored by Chang Liu, committed by GitHub

[Example][Bugfix] Fix mlp example (#4631)

* Fix mlp example

* Remove unused file
parent 188bc2bf
@@ -21,7 +21,6 @@ from torch import nn
 from tqdm import tqdm

 from models import MLP
-from utils import BatchSampler, DataLoaderWrapper

 epsilon = 1 - math.log(2)
@@ -107,6 +106,7 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
         loss_sum += loss.item() * count
         total += count

+    preds = preds.to(train_idx.device)
     return (
         loss_sum / total,
         evaluator(preds[train_idx], labels[train_idx]),
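The one-line addition in this hunk guards against a device mismatch: the example evidently builds `preds` on a different device than `train_idx` (presumably the predictions are assembled on the CPU from mini-batch outputs while the index tensor stays on the GPU), and indexing a tensor with an index tensor on another device raises a RuntimeError. A minimal PyTorch sketch of the failure mode and the fix; the tensor names and sizes here are illustrative, not taken from the example:

import torch

preds = torch.randn(8, 3)               # accumulated on the CPU during mini-batch inference
train_idx = torch.tensor([0, 2, 5])
if torch.cuda.is_available():
    train_idx = train_idx.cuda()         # index tensor may live on the GPU

# Without the fix, preds[train_idx] fails with a device-mismatch RuntimeError
# whenever the two tensors sit on different devices; aligning them first is
# harmless on CPU-only runs and fixes the GPU case.
preds = preds.to(train_idx.device)
print(preds[train_idx].shape)            # torch.Size([3, 3])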
@@ -152,14 +152,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     train_batch_size = 4096
     train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
-    train_dataloader = DataLoaderWrapper(
-        DataLoader(
-            graph.cpu(),
-            train_idx.cpu(),
-            train_sampler,
-            batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
-            num_workers=4,
-        )
-    )
+    train_dataloader = DataLoader(
+        graph.cpu(),
+        train_idx.cpu(),
+        train_sampler,
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=4
+    )

     eval_batch_size = 4096
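The replacement above (and the matching changes in the two evaluation hunks below) drops the custom wrapper and lets DGL's loader batch and shuffle the seed nodes itself. A minimal sketch of the new construction, assuming DGL >= 0.8, where `dgl.dataloading.DataLoader` accepts `batch_size`, `shuffle`, and `num_workers` directly; the toy random graph, feature name, and sizes are placeholders, not the example's own graph and index splits:

import dgl
import torch
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler

# Placeholder graph and training split.
graph = dgl.rand_graph(1000, 5000)
graph.ndata["feat"] = torch.randn(1000, 16)
train_idx = torch.arange(600)

# Fanout 0 per layer: no neighbors are sampled, so the MLP only sees the seed nodes.
sampler = MultiLayerNeighborSampler([0])
train_dataloader = DataLoader(
    graph,
    train_idx,
    sampler,
    batch_size=256,
    shuffle=True,
    num_workers=0,           # 0 keeps this sketch single-process; the example uses 4
)

for input_nodes, output_nodes, blocks in train_dataloader:
    batch_feat = graph.ndata["feat"][output_nodes]   # features of this batch's seed nodes
    break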
@@ -168,14 +167,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
         eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])
     else:
         eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
-    eval_dataloader = DataLoaderWrapper(
-        DataLoader(
-            graph.cpu(),
-            eval_idx,
-            eval_sampler,
-            batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
-            num_workers=4,
-        )
-    )
+    eval_dataloader = DataLoader(
+        graph.cpu(),
+        eval_idx,
+        eval_sampler,
+        batch_size=eval_batch_size,
+        shuffle=False,
+        num_workers=4
+    )

     model = gen_model(args).to(device)
@@ -195,7 +193,6 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     for epoch in range(1, args.n_epochs + 1):
         tic = time.time()

         loss, score = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
-
         toc = time.time()
@@ -233,14 +230,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     if args.eval_last:
         model.load_state_dict(best_model_state_dict)
-        eval_dataloader = DataLoaderWrapper(
-            DataLoader(
-                graph.cpu(),
-                test_idx.cpu(),
-                eval_sampler,
-                batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
-                num_workers=4,
-            )
-        )
+        eval_dataloader = DataLoader(
+            graph.cpu(),
+            test_idx.cpu(),
+            eval_sampler,
+            batch_size=eval_batch_size,
+            shuffle=False,
+            num_workers=4
+        )

         final_test_score = evaluate(
             args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
utils.py (removed by this commit; its helpers are no longer imported anywhere):

import torch


class DataLoaderWrapper(object):
    def __init__(self, dataloader):
        self.iter = iter(dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iter)
        except Exception:
            raise StopIteration() from None


class BatchSampler(object):
    def __init__(self, n, batch_size, shuffle=False):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        while True:
            if self.shuffle:
                perm = torch.randperm(self.n)
            else:
                perm = torch.arange(start=0, end=self.n)
            shuf = perm.split(self.batch_size)
            for shuf_batch in shuf:
                yield shuf_batch
            yield None
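For context on why this file could simply be deleted: `BatchSampler.__iter__` never terminates on its own; it yields index batches for one pass over the data and then a `None` sentinel, and `DataLoaderWrapper` converts whatever exception that sentinel eventually provokes downstream into `StopIteration`, so a single long-lived loader could be re-iterated epoch after epoch. A plain `DataLoader` already ends each epoch cleanly when re-iterated, which is all the rewritten example needs. Below is a tiny pure-Python reproduction of the old sentinel mechanism, with a toy `collate` standing in for DGL's per-batch sampling; it is not the original wiring, only an illustration of the pattern:

import torch


class BatchSampler:
    # Endless sampler: one (optionally shuffled) pass over n indices, then a None
    # sentinel marking the epoch boundary, repeated forever.
    def __init__(self, n, batch_size, shuffle=False):
        self.n, self.batch_size, self.shuffle = n, batch_size, shuffle

    def __iter__(self):
        while True:
            perm = torch.randperm(self.n) if self.shuffle else torch.arange(self.n)
            yield from perm.split(self.batch_size)
            yield None


def collate(batch_idx):
    # Stand-in for the work DGL would do per batch; it fails on the None sentinel.
    return batch_idx * 2


class DataLoaderWrapper:
    # Converts the sentinel-induced failure into StopIteration so a for-loop ends cleanly.
    def __init__(self, batches):
        self.iter = iter(batches)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return collate(next(self.iter))
        except Exception:
            raise StopIteration() from None


loader = DataLoaderWrapper(BatchSampler(10, batch_size=4, shuffle=True))
for epoch in range(2):
    for batch in loader:                 # the same wrapper is re-iterated each epoch
        print(epoch, batch.shape)        # three batches per epoch; the loop ends at the sentinel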