Unverified Commit 549df65a authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix example case: examples/pytorch/ogb/ogbn-proteins and...

[Bugfix] Fix example case: examples/pytorch/ogb/ogbn-proteins and examples/pytorch/ogb/ogbn-products (#4080)

* [Bugfix] Fix ogbn-gat-proteins/products examples

* Remove unused BatchSampler definition

* Remove comments to ease reading/reviewing

* Remove dataloader wrapper
parent cac3720b
...@@ -22,7 +22,6 @@ from torch import nn ...@@ -22,7 +22,6 @@ from torch import nn
from tqdm import tqdm from tqdm import tqdm
from models import GAT from models import GAT
from utils import BatchSampler, DataLoaderWrapper
epsilon = 1 - math.log(2) epsilon = 1 - math.log(2)
...@@ -219,15 +218,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -219,15 +218,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
n_train_samples = train_idx.shape[0] n_train_samples = train_idx.shape[0]
train_batch_size = (n_train_samples + 29) // 30 train_batch_size = (n_train_samples + 29) // 30
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)]) train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
train_dataloader = DataLoaderWrapper( train_dataloader = DataLoader(
DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True), batch_size=train_batch_size, shuffle=True,
num_workers=4, num_workers=4,
) )
)
eval_batch_size = 32768 eval_batch_size = 32768
eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)]) eval_sampler = MultiLayerNeighborSampler([15 for _ in range(args.n_layers)])
...@@ -238,15 +235,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -238,15 +235,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
test_idx_during_training = test_idx test_idx_during_training = test_idx
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()]) eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoader(
DataLoader(
graph.cpu(), graph.cpu(),
eval_idx, eval_idx,
eval_sampler, eval_sampler,
batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False), batch_size=eval_batch_size, shuffle=False,
num_workers=4, num_workers=4,
) )
)
model = gen_model(args).to(device) model = gen_model(args).to(device)
...@@ -292,7 +287,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -292,7 +287,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if epoch == args.n_epochs or epoch % args.log_every == 0: if epoch == args.n_epochs or epoch % args.log_every == 0:
print( print(
f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2s}\n" f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
f"Loss: {loss:.4f}, Score: {score:.4f}\n" f"Loss: {loss:.4f}, Score: {score:.4f}\n"
f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n" f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}" f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
...@@ -308,15 +303,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -308,15 +303,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.estimation_mode: if args.estimation_mode:
model.load_state_dict(best_model_state_dict) model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoader(
DataLoader(
graph.cpu(), graph.cpu(),
test_idx.cpu(), test_idx.cpu(),
eval_sampler, eval_sampler,
batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False), batch_size=eval_batch_size, shuffle=False,
num_workers=4, num_workers=4,
) )
)
final_test_score = evaluate( final_test_score = evaluate(
args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
)[2] )[2]
......
import torch
class DataLoaderWrapper(object):
    """Single-pass iterator facade over a dataloader.

    Grabs the underlying iterator eagerly at construction time and maps
    *any* exception raised during iteration to ``StopIteration``.  This
    pairs with ``BatchSampler`` below, whose trailing ``None`` batch makes
    the wrapped dataloader fail — the wrapper turns that failure into a
    clean end-of-epoch signal for ``for`` loops.

    NOTE(review): the blanket ``except Exception`` also hides genuine
    errors (e.g. worker crashes) by silently ending iteration.
    """

    def __init__(self, dataloader):
        # One-shot: the iterator is consumed here and never re-created.
        self._it = iter(dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._it)
        except Exception:
            # Collapse every failure mode into normal loop termination.
            raise StopIteration() from None
class BatchSampler(object):
    """Endless epoch-batch generator over indices ``0 .. n-1``.

    Each epoch yields index tensors of at most ``batch_size`` elements
    (re-permuted per epoch when ``shuffle`` is set), followed by a single
    ``None`` as an end-of-epoch marker before the next epoch begins.
    The ``None`` sentinel is what lets ``DataLoaderWrapper`` terminate a
    ``for`` loop over this otherwise infinite iterator.
    """

    def __init__(self, n, batch_size, shuffle=False):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Fixed ordering is computed once up front; shuffled ordering is
        # drawn fresh at the start of every epoch.
        fixed = None if self.shuffle else torch.arange(start=0, end=self.n)
        while True:
            order = torch.randperm(self.n) if self.shuffle else fixed
            yield from order.split(self.batch_size)
            yield None  # end-of-epoch sentinel
...@@ -21,7 +21,6 @@ from ogb.nodeproppred import DglNodePropPredDataset, Evaluator ...@@ -21,7 +21,6 @@ from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
from models import GAT from models import GAT
from utils import BatchSampler, DataLoaderWrapper
device = None device = None
dataset = "ogbn-proteins" dataset = "ogbn-proteins"
...@@ -178,27 +177,23 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -178,27 +177,23 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
# batch_size = len(train_idx) # batch_size = len(train_idx)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)]) train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers) # sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoaderWrapper( train_dataloader = DataLoader(
DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size), batch_size=train_batch_size,
num_workers=10, num_workers=10,
) )
)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)]) eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers) # sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoader(
DataLoader(
graph.cpu(), graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]), torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler, eval_sampler,
batch_sampler=BatchSampler(graph.number_of_nodes(), batch_size=65536), batch_size=65536,
num_workers=10, num_workers=10,
) )
)
criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss()
......
...@@ -88,29 +88,3 @@ class Logger(object): ...@@ -88,29 +88,3 @@ class Logger(object):
r = best_result[:, 3] r = best_result[:, 3]
print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}") print(f" Final Test: {r.mean():.2f} ± {r.std():.2f}")
class DataLoaderWrapper(object):
    """One-shot iterator adapter that ends cleanly on any error.

    ``__next__`` forwards to the wrapped dataloader's iterator but
    re-raises every exception — not just ``StopIteration`` — as
    ``StopIteration``.  Combined with the ``None`` terminator emitted by
    ``BatchSampler``, this turns an infinite batch stream into a finite
    ``for``-loop epoch.

    NOTE(review): the broad ``except`` can also mask real dataloader
    failures as a silent end of iteration.
    """

    def __init__(self, dataloader):
        self.iter = iter(dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            item = next(self.iter)
        except Exception:
            raise StopIteration() from None
        return item
class BatchSampler(object):
    """Infinite shuffled index-batch stream with a ``None`` epoch marker.

    Every epoch draws a fresh random permutation of ``0 .. n-1``, yields
    it in chunks of at most ``batch_size`` indices, then yields ``None``
    so that ``DataLoaderWrapper`` can stop the surrounding loop.
    """

    def __init__(self, n, batch_size):
        self.n = n
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            yield from torch.randperm(self.n).split(self.batch_size)
            yield None  # end-of-epoch sentinel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment