Unverified commit 74f01405 authored by Quan (Andy) Gan, committed by GitHub
Browse files

[Example] Rename NodeDataLoader to DataLoader in GraphSAGE example (#3972)



* rename

* Update node_classification.py

* more fixes...
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 22d7f924
......@@ -103,7 +103,7 @@ def main(args):
tr_recall = 0
tr_auc = 0
tr_blk = 0
train_dataloader = dgl.dataloading.NodeDataLoader(graph,
train_dataloader = dgl.dataloading.DataLoader(graph,
train_idx,
sampler,
batch_size=args.batch_size,
......@@ -135,7 +135,7 @@ def main(args):
# validation
model.eval()
val_dataloader = dgl.dataloading.NodeDataLoader(graph,
val_dataloader = dgl.dataloading.DataLoader(graph,
val_idx,
sampler,
batch_size=args.batch_size,
......@@ -159,7 +159,7 @@ def main(args):
model.eval()
if args.early_stop:
model.load_state_dict(th.load('es_checkpoint.pt'))
test_dataloader = dgl.dataloading.NodeDataLoader(graph,
test_dataloader = dgl.dataloading.DataLoader(graph,
test_idx,
sampler,
batch_size=args.batch_size,
......
......@@ -38,8 +38,6 @@ sampler = dgl.dataloading.ClusterGCNSampler(
prefetch_ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask'])
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler.
# NodeDataLoader and EdgeDataLoader are simply special cases of DataLoader where the
# indices are guaranteed to be node and edge IDs.
dataloader = dgl.dataloading.DataLoader(
graph,
torch.arange(num_partitions).to('cuda'),
......
......@@ -87,7 +87,7 @@ class GCMCGraphConv(nn.Module):
if weight is not None:
feat = dot_or_identity(feat, weight, self.device)
feat = feat * self.dropout(cj)
feat = feat * self.dropout(cj).view(-1, 1)
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
......@@ -342,7 +342,7 @@ class BiDecoder(nn.Module):
graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
basis_out.append(graph.edata['sr'])
out = th.cat(basis_out, dim=1)
out = self.combine_basis(out)
#out = self.combine_basis(out)
return out
class DenseBiDecoder(nn.Module):
......
......@@ -54,7 +54,7 @@ class SAGE(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()).to(g.device),
sampler,
......
......@@ -35,7 +35,7 @@ class SAGE(nn.Module):
# example is that the intermediate results can also benefit from prefetching.
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=num_workers)
......@@ -84,7 +84,7 @@ sampler = dgl.dataloading.NeighborSampler(
train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=False)
valid_dataloader = dgl.dataloading.NodeDataLoader(
valid_dataloader = dgl.dataloading.DataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=False)
......
......@@ -136,9 +136,9 @@ class DataModule(LightningDataModule):
num_workers=self.num_workers)
def val_dataloader(self):
# Note that the validation data loader is a NodeDataLoader
# Note that the validation data loader is a DataLoader
# as we want to evaluate all the node embeddings.
return dgl.dataloading.NodeDataLoader(
return dgl.dataloading.DataLoader(
self.g,
np.arange(self.g.num_nodes()),
self.sampler,
......
......@@ -41,7 +41,7 @@ class SAGE(LightningModule):
# example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
persistent_workers=(num_workers > 0))
......
......@@ -77,7 +77,7 @@ class SAGE(nn.Module):
# example is that the intermediate results can also benefit from prefetching.
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
if buffer_device is None:
......
......@@ -38,7 +38,7 @@ class SAGE(nn.Module):
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=num_workers)
......@@ -85,7 +85,7 @@ sampler = dgl.dataloading.NeighborSampler(
train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
valid_dataloader = dgl.dataloading.NodeDataLoader(
valid_dataloader = dgl.dataloading.DataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
......
......@@ -71,7 +71,7 @@ global_num_nodes = g.number_of_nodes()
fanouts = [args.knn_k-1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
test_loader = dgl.dataloading.NodeDataLoader(
test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size,
shuffle=False,
......@@ -135,7 +135,7 @@ for level in range(args.levels):
g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
test_loader = dgl.dataloading.NodeDataLoader(
test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size,
shuffle=False,
......
......@@ -73,7 +73,7 @@ def set_train_sampler_loader(g, k):
fanouts = [k-1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges
train_dataloader = dgl.dataloading.NodeDataLoader(
train_dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size,
shuffle=True,
......
......@@ -76,7 +76,7 @@ class GAT(nn.Module):
else:
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()),
sampler,
......
......@@ -353,7 +353,7 @@ def prepare_data(args):
# train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(args['fanout'])
train_loader = dgl.dataloading.NodeDataLoader(
train_loader = dgl.dataloading.DataLoader(
g, split_idx['train'], sampler,
batch_size=args['batch_size'], shuffle=True, num_workers=0)
......@@ -439,7 +439,7 @@ def test(g, model, node_embed, y_true, device, split_idx, args):
evaluator = Evaluator(name='ogbn-mag')
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['num_layers'])
loader = dgl.dataloading.NodeDataLoader(
loader = dgl.dataloading.DataLoader(
g, {'paper': th.arange(g.num_nodes('paper'))}, sampler,
batch_size=16384, shuffle=False, num_workers=0)
......
......@@ -15,7 +15,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
......@@ -220,7 +220,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_batch_size = (n_train_samples + 29) // 30
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
......@@ -239,7 +239,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
......@@ -309,7 +309,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.estimation_mode:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
......
......@@ -68,7 +68,7 @@ class GAT(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()),
sampler,
......@@ -132,7 +132,7 @@ def run(args, device, data):
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
sampler,
......
......@@ -63,7 +63,7 @@ class SAGE(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes).to(device)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes()),
sampler,
......@@ -124,7 +124,7 @@ def run(args, device, data):
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
train_nid,
sampler,
......
......@@ -14,7 +14,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
......@@ -153,7 +153,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_batch_size = 4096
train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)]) # do not sample neighbors
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
......@@ -169,7 +169,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
else:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
eval_idx,
eval_sampler,
......@@ -234,7 +234,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.eval_last:
model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
test_idx.cpu(),
eval_sampler,
......
......@@ -15,7 +15,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader
from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn
......@@ -179,7 +179,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
train_idx.cpu(),
train_sampler,
......@@ -191,7 +191,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoaderWrapper(
NodeDataLoader(
DataLoader(
graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler,
......
......@@ -21,7 +21,7 @@ def train(
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader,
dataloader: dgl.dataloading.DataLoader,
) -> Tuple[float]:
model.train()
......@@ -78,7 +78,7 @@ def validate(
hg: dgl.DGLHeteroGraph,
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader = None,
dataloader: dgl.dataloading.DataLoader = None,
eval_batch_size: int = None,
eval_num_workers: int = None,
mask: torch.Tensor = None,
......@@ -173,7 +173,7 @@ def run(args: argparse.ArgumentParser) -> None:
fanouts = [int(fanout) for fanout in args.fanouts.split(',')]
train_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
train_dataloader = dgl.dataloading.NodeDataLoader(
train_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: train_idx},
train_sampler,
......@@ -185,7 +185,7 @@ def run(args: argparse.ArgumentParser) -> None:
if inferfence_mode == 'neighbor_sampler':
valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
valid_dataloader = dgl.dataloading.NodeDataLoader(
valid_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: valid_idx},
valid_sampler,
......@@ -197,7 +197,7 @@ def run(args: argparse.ArgumentParser) -> None:
if args.test_validation:
test_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_dataloader = dgl.dataloading.NodeDataLoader(
test_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: test_idx},
test_sampler,
......
......@@ -286,7 +286,7 @@ class EntityClassify(nn.Module):
) -> Dict[str, torch.Tensor]:
for i, layer in enumerate(self._layers):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
hg,
{ntype: hg.nodes(ntype) for ntype in hg.ntypes},
sampler,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment