Unverified commit 74f01405, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Example] Rename NodeDataLoader to DataLoader in GraphSAGE example (#3972)



* rename

* Update node_classification.py

* more fixes...
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 22d7f924
...@@ -103,7 +103,7 @@ def main(args): ...@@ -103,7 +103,7 @@ def main(args):
tr_recall = 0 tr_recall = 0
tr_auc = 0 tr_auc = 0
tr_blk = 0 tr_blk = 0
train_dataloader = dgl.dataloading.NodeDataLoader(graph, train_dataloader = dgl.dataloading.DataLoader(graph,
train_idx, train_idx,
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
...@@ -135,7 +135,7 @@ def main(args): ...@@ -135,7 +135,7 @@ def main(args):
# validation # validation
model.eval() model.eval()
val_dataloader = dgl.dataloading.NodeDataLoader(graph, val_dataloader = dgl.dataloading.DataLoader(graph,
val_idx, val_idx,
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
...@@ -159,7 +159,7 @@ def main(args): ...@@ -159,7 +159,7 @@ def main(args):
model.eval() model.eval()
if args.early_stop: if args.early_stop:
model.load_state_dict(th.load('es_checkpoint.pt')) model.load_state_dict(th.load('es_checkpoint.pt'))
test_dataloader = dgl.dataloading.NodeDataLoader(graph, test_dataloader = dgl.dataloading.DataLoader(graph,
test_idx, test_idx,
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
......
...@@ -38,8 +38,6 @@ sampler = dgl.dataloading.ClusterGCNSampler( ...@@ -38,8 +38,6 @@ sampler = dgl.dataloading.ClusterGCNSampler(
prefetch_ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask']) prefetch_ndata=['feat', 'label', 'train_mask', 'val_mask', 'test_mask'])
# DataLoader for generic dataloading with a graph, a set of indices (any indices, like # DataLoader for generic dataloading with a graph, a set of indices (any indices, like
# partition IDs here), and a graph sampler. # partition IDs here), and a graph sampler.
# NodeDataLoader and EdgeDataLoader are simply special cases of DataLoader where the
# indices are guaranteed to be node and edge IDs.
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
graph, graph,
torch.arange(num_partitions).to('cuda'), torch.arange(num_partitions).to('cuda'),
......
...@@ -87,7 +87,7 @@ class GCMCGraphConv(nn.Module): ...@@ -87,7 +87,7 @@ class GCMCGraphConv(nn.Module):
if weight is not None: if weight is not None:
feat = dot_or_identity(feat, weight, self.device) feat = dot_or_identity(feat, weight, self.device)
feat = feat * self.dropout(cj) feat = feat * self.dropout(cj).view(-1, 1)
graph.srcdata['h'] = feat graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'), graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h')) fn.sum(msg='m', out='h'))
...@@ -342,7 +342,7 @@ class BiDecoder(nn.Module): ...@@ -342,7 +342,7 @@ class BiDecoder(nn.Module):
graph.apply_edges(fn.u_dot_v('h', 'h', 'sr')) graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
basis_out.append(graph.edata['sr']) basis_out.append(graph.edata['sr'])
out = th.cat(basis_out, dim=1) out = th.cat(basis_out, dim=1)
out = self.combine_basis(out) #out = self.combine_basis(out)
return out return out
class DenseBiDecoder(nn.Module): class DenseBiDecoder(nn.Module):
......
...@@ -54,7 +54,7 @@ class SAGE(nn.Module): ...@@ -54,7 +54,7 @@ class SAGE(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
th.arange(g.num_nodes()).to(g.device), th.arange(g.num_nodes()).to(g.device),
sampler, sampler,
......
...@@ -35,7 +35,7 @@ class SAGE(nn.Module): ...@@ -35,7 +35,7 @@ class SAGE(nn.Module):
# example is that the intermediate results can also benefit from prefetching. # example is that the intermediate results can also benefit from prefetching.
feat = g.ndata['feat'] feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=num_workers) num_workers=num_workers)
...@@ -84,7 +84,7 @@ sampler = dgl.dataloading.NeighborSampler( ...@@ -84,7 +84,7 @@ sampler = dgl.dataloading.NeighborSampler(
train_dataloader = dgl.dataloading.DataLoader( train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=False) drop_last=False, num_workers=0, use_uva=False)
valid_dataloader = dgl.dataloading.NodeDataLoader( valid_dataloader = dgl.dataloading.DataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=False) drop_last=False, num_workers=0, use_uva=False)
......
...@@ -136,9 +136,9 @@ class DataModule(LightningDataModule): ...@@ -136,9 +136,9 @@ class DataModule(LightningDataModule):
num_workers=self.num_workers) num_workers=self.num_workers)
def val_dataloader(self): def val_dataloader(self):
# Note that the validation data loader is a NodeDataLoader # Note that the validation data loader is a DataLoader
# as we want to evaluate all the node embeddings. # as we want to evaluate all the node embeddings.
return dgl.dataloading.NodeDataLoader( return dgl.dataloading.DataLoader(
self.g, self.g,
np.arange(self.g.num_nodes()), np.arange(self.g.num_nodes()),
self.sampler, self.sampler,
......
...@@ -41,7 +41,7 @@ class SAGE(LightningModule): ...@@ -41,7 +41,7 @@ class SAGE(LightningModule):
# example is that the intermediate results can also benefit from prefetching. # example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat'] g.ndata['h'] = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
persistent_workers=(num_workers > 0)) persistent_workers=(num_workers > 0))
......
...@@ -77,7 +77,7 @@ class SAGE(nn.Module): ...@@ -77,7 +77,7 @@ class SAGE(nn.Module):
# example is that the intermediate results can also benefit from prefetching. # example is that the intermediate results can also benefit from prefetching.
feat = g.ndata['feat'] feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers) batch_size=1000, shuffle=False, drop_last=False, num_workers=num_workers)
if buffer_device is None: if buffer_device is None:
......
...@@ -38,7 +38,7 @@ class SAGE(nn.Module): ...@@ -38,7 +38,7 @@ class SAGE(nn.Module):
def inference(self, g, device, batch_size, num_workers, buffer_device=None): def inference(self, g, device, batch_size, num_workers, buffer_device=None):
feat = g.ndata['feat'] feat = g.ndata['feat']
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, batch_size=batch_size, shuffle=False, drop_last=False,
num_workers=num_workers) num_workers=num_workers)
...@@ -85,7 +85,7 @@ sampler = dgl.dataloading.NeighborSampler( ...@@ -85,7 +85,7 @@ sampler = dgl.dataloading.NeighborSampler(
train_dataloader = dgl.dataloading.DataLoader( train_dataloader = dgl.dataloading.DataLoader(
graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=not args.pure_gpu) drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
valid_dataloader = dgl.dataloading.NodeDataLoader( valid_dataloader = dgl.dataloading.DataLoader(
graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True, graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
drop_last=False, num_workers=0, use_uva=not args.pure_gpu) drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
......
...@@ -71,7 +71,7 @@ global_num_nodes = g.number_of_nodes() ...@@ -71,7 +71,7 @@ global_num_nodes = g.number_of_nodes()
fanouts = [args.knn_k-1 for i in range(args.num_conv + 1)] fanouts = [args.knn_k-1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges # fix the number of edges
test_loader = dgl.dataloading.NodeDataLoader( test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler, g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False, shuffle=False,
...@@ -135,7 +135,7 @@ for level in range(args.levels): ...@@ -135,7 +135,7 @@ for level in range(args.levels):
g = dataset.gs[0] g = dataset.gs[0]
g.ndata['pred_den'] = torch.zeros((g.number_of_nodes())) g.ndata['pred_den'] = torch.zeros((g.number_of_nodes()))
g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2)) g.edata['prob_conn'] = torch.zeros((g.number_of_edges(), 2))
test_loader = dgl.dataloading.NodeDataLoader( test_loader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler, g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False, shuffle=False,
......
...@@ -73,7 +73,7 @@ def set_train_sampler_loader(g, k): ...@@ -73,7 +73,7 @@ def set_train_sampler_loader(g, k):
fanouts = [k-1 for i in range(args.num_conv + 1)] fanouts = [k-1 for i in range(args.num_conv + 1)]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
# fix the number of edges # fix the number of edges
train_dataloader = dgl.dataloading.NodeDataLoader( train_dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.number_of_nodes()), sampler, g, torch.arange(g.number_of_nodes()), sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
......
...@@ -76,7 +76,7 @@ class GAT(nn.Module): ...@@ -76,7 +76,7 @@ class GAT(nn.Module):
else: else:
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
th.arange(g.num_nodes()), th.arange(g.num_nodes()),
sampler, sampler,
......
...@@ -353,7 +353,7 @@ def prepare_data(args): ...@@ -353,7 +353,7 @@ def prepare_data(args):
# train sampler # train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler(args['fanout']) sampler = dgl.dataloading.MultiLayerNeighborSampler(args['fanout'])
train_loader = dgl.dataloading.NodeDataLoader( train_loader = dgl.dataloading.DataLoader(
g, split_idx['train'], sampler, g, split_idx['train'], sampler,
batch_size=args['batch_size'], shuffle=True, num_workers=0) batch_size=args['batch_size'], shuffle=True, num_workers=0)
...@@ -439,7 +439,7 @@ def test(g, model, node_embed, y_true, device, split_idx, args): ...@@ -439,7 +439,7 @@ def test(g, model, node_embed, y_true, device, split_idx, args):
evaluator = Evaluator(name='ogbn-mag') evaluator = Evaluator(name='ogbn-mag')
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['num_layers']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['num_layers'])
loader = dgl.dataloading.NodeDataLoader( loader = dgl.dataloading.DataLoader(
g, {'paper': th.arange(g.num_nodes('paper'))}, sampler, g, {'paper': th.arange(g.num_nodes('paper'))}, sampler,
batch_size=16384, shuffle=False, num_workers=0) batch_size=16384, shuffle=False, num_workers=0)
......
...@@ -15,7 +15,7 @@ import torch ...@@ -15,7 +15,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
...@@ -220,7 +220,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -220,7 +220,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_batch_size = (n_train_samples + 29) // 30 train_batch_size = (n_train_samples + 29) // 30
train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)]) train_sampler = MultiLayerNeighborSampler([10 for _ in range(args.n_layers)])
train_dataloader = DataLoaderWrapper( train_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
...@@ -239,7 +239,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -239,7 +239,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()]) eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx_during_training.cpu()])
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
eval_idx, eval_idx,
eval_sampler, eval_sampler,
...@@ -309,7 +309,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -309,7 +309,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.estimation_mode: if args.estimation_mode:
model.load_state_dict(best_model_state_dict) model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
test_idx.cpu(), test_idx.cpu(),
eval_sampler, eval_sampler,
......
...@@ -68,7 +68,7 @@ class GAT(nn.Module): ...@@ -68,7 +68,7 @@ class GAT(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
th.arange(g.num_nodes()), th.arange(g.num_nodes()),
sampler, sampler,
...@@ -132,7 +132,7 @@ def run(args, device, data): ...@@ -132,7 +132,7 @@ def run(args, device, data):
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')]) [int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
train_nid, train_nid,
sampler, sampler,
......
...@@ -63,7 +63,7 @@ class SAGE(nn.Module): ...@@ -63,7 +63,7 @@ class SAGE(nn.Module):
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes).to(device) y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes).to(device)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
th.arange(g.num_nodes()), th.arange(g.num_nodes()),
sampler, sampler,
...@@ -124,7 +124,7 @@ def run(args, device, data): ...@@ -124,7 +124,7 @@ def run(args, device, data):
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')]) [int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
train_nid, train_nid,
sampler, sampler,
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
...@@ -153,7 +153,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -153,7 +153,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_batch_size = 4096 train_batch_size = 4096
train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)]) # no not sample neighbors train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)]) # no not sample neighbors
train_dataloader = DataLoaderWrapper( train_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
...@@ -169,7 +169,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -169,7 +169,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
else: else:
eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]) eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
eval_idx, eval_idx,
eval_sampler, eval_sampler,
...@@ -234,7 +234,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -234,7 +234,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
if args.eval_last: if args.eval_last:
model.load_state_dict(best_model_state_dict) model.load_state_dict(best_model_state_dict)
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
test_idx.cpu(), test_idx.cpu(),
eval_sampler, eval_sampler,
......
...@@ -15,7 +15,7 @@ import torch ...@@ -15,7 +15,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader from dgl.dataloading import DataLoader
from matplotlib.ticker import AutoMinorLocator, MultipleLocator from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn from torch import nn
...@@ -179,7 +179,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -179,7 +179,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)]) train_sampler = MultiLayerNeighborSampler([32 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers) # sampler = MultiLayerFullNeighborSampler(args.n_layers)
train_dataloader = DataLoaderWrapper( train_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
train_idx.cpu(), train_idx.cpu(),
train_sampler, train_sampler,
...@@ -191,7 +191,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running) ...@@ -191,7 +191,7 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)]) eval_sampler = MultiLayerNeighborSampler([100 for _ in range(args.n_layers)])
# sampler = MultiLayerFullNeighborSampler(args.n_layers) # sampler = MultiLayerFullNeighborSampler(args.n_layers)
eval_dataloader = DataLoaderWrapper( eval_dataloader = DataLoaderWrapper(
NodeDataLoader( DataLoader(
graph.cpu(), graph.cpu(),
torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]), torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
eval_sampler, eval_sampler,
......
...@@ -21,7 +21,7 @@ def train( ...@@ -21,7 +21,7 @@ def train(
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
labels: torch.Tensor, labels: torch.Tensor,
predict_category: str, predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader, dataloader: dgl.dataloading.DataLoader,
) -> Tuple[float]: ) -> Tuple[float]:
model.train() model.train()
...@@ -78,7 +78,7 @@ def validate( ...@@ -78,7 +78,7 @@ def validate(
hg: dgl.DGLHeteroGraph, hg: dgl.DGLHeteroGraph,
labels: torch.Tensor, labels: torch.Tensor,
predict_category: str, predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader = None, dataloader: dgl.dataloading.DataLoader = None,
eval_batch_size: int = None, eval_batch_size: int = None,
eval_num_workers: int = None, eval_num_workers: int = None,
mask: torch.Tensor = None, mask: torch.Tensor = None,
...@@ -173,7 +173,7 @@ def run(args: argparse.ArgumentParser) -> None: ...@@ -173,7 +173,7 @@ def run(args: argparse.ArgumentParser) -> None:
fanouts = [int(fanout) for fanout in args.fanouts.split(',')] fanouts = [int(fanout) for fanout in args.fanouts.split(',')]
train_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) train_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
train_dataloader = dgl.dataloading.NodeDataLoader( train_dataloader = dgl.dataloading.DataLoader(
hg, hg,
{predict_category: train_idx}, {predict_category: train_idx},
train_sampler, train_sampler,
...@@ -185,7 +185,7 @@ def run(args: argparse.ArgumentParser) -> None: ...@@ -185,7 +185,7 @@ def run(args: argparse.ArgumentParser) -> None:
if inferfence_mode == 'neighbor_sampler': if inferfence_mode == 'neighbor_sampler':
valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
valid_dataloader = dgl.dataloading.NodeDataLoader( valid_dataloader = dgl.dataloading.DataLoader(
hg, hg,
{predict_category: valid_idx}, {predict_category: valid_idx},
valid_sampler, valid_sampler,
...@@ -197,7 +197,7 @@ def run(args: argparse.ArgumentParser) -> None: ...@@ -197,7 +197,7 @@ def run(args: argparse.ArgumentParser) -> None:
if args.test_validation: if args.test_validation:
test_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) test_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_dataloader = dgl.dataloading.NodeDataLoader( test_dataloader = dgl.dataloading.DataLoader(
hg, hg,
{predict_category: test_idx}, {predict_category: test_idx},
test_sampler, test_sampler,
......
...@@ -286,7 +286,7 @@ class EntityClassify(nn.Module): ...@@ -286,7 +286,7 @@ class EntityClassify(nn.Module):
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
for i, layer in enumerate(self._layers): for i, layer in enumerate(self._layers):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader( dataloader = dgl.dataloading.DataLoader(
hg, hg,
{ntype: hg.nodes(ntype) for ntype in hg.ntypes}, {ntype: hg.nodes(ntype) for ntype in hg.ntypes},
sampler, sampler,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment