Unverified Commit 74f01405 authored by Quan (Andy) Gan, committed by GitHub

[Example] Rename NodeDataLoader to DataLoader in GraphSAGE example (#3972)



* rename

* Update node_classification.py

* more fixes...
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 22d7f924
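
The commit switches the example scripts from NodeDataLoader to the unified dgl.dataloading.DataLoader. As a minimal, hedged sketch (not part of the diff) of the call pattern the hunks below migrate to, assuming `g`, `category`, `train_idx`, and `args` are defined by the surrounding example script:

import dgl

# Build a neighbor sampler and the new-style DataLoader
# (dgl.dataloading.DataLoader replaces dgl.dataloading.NodeDataLoader).
sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
loader = dgl.dataloading.DataLoader(
    g, {category: train_idx}, sampler,
    batch_size=args.batch_size, shuffle=True, num_workers=0)

# Iteration yields (input_nodes, output_nodes, blocks) per mini-batch,
# the same pattern the examples already use.
for input_nodes, output_nodes, blocks in loader:
    pass  # run one training step on the sampled blocks
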
@@ -97,14 +97,14 @@ def main(args):
     # train sampler
     sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
-    loader = dgl.dataloading.NodeDataLoader(
+    loader = dgl.dataloading.DataLoader(
         g, {category: train_idx}, sampler,
         batch_size=args.batch_size, shuffle=True, num_workers=0)
     # validation sampler
     # we do not use full neighbor to save computation resources
     val_sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
-    val_loader = dgl.dataloading.NodeDataLoader(
+    val_loader = dgl.dataloading.DataLoader(
         g, {category: val_idx}, val_sampler,
         batch_size=args.batch_size, shuffle=True, num_workers=0)
...
@@ -350,7 +350,7 @@ class EntityClassify(nn.Module):
                   for k in g.ntypes}
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
             g,
             {k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
             sampler,
@@ -445,7 +445,7 @@ class EntityClassify_HeteroAPI(nn.Module):
                   for k in g.ntypes}
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
-        dataloader = dgl.dataloading.NodeDataLoader(
+        dataloader = dgl.dataloading.DataLoader(
             g,
             {k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
             sampler,
...
@@ -8,7 +8,7 @@ import torch as th
 import torch.nn.functional as F
 import dgl
-from dgl.dataloading import MultiLayerNeighborSampler, NodeDataLoader
+from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
 from torchmetrics.functional import accuracy
 from tqdm import tqdm
@@ -19,7 +19,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
     fanouts = [int(fanout) for fanout in args.fanout.split(',')]
     sampler = MultiLayerNeighborSampler(fanouts)
-    train_loader = NodeDataLoader(
+    train_loader = DataLoader(
         g,
         target_idx[train_idx],
         sampler,
@@ -30,7 +30,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
         drop_last=False)
     # The datasets do not have a validation subset, use the train subset
-    val_loader = NodeDataLoader(
+    val_loader = DataLoader(
         g,
         target_idx[train_idx],
         sampler,
@@ -42,7 +42,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
     # -1 for sampling all neighbors
     test_sampler = MultiLayerNeighborSampler([-1] * len(fanouts))
-    test_loader = NodeDataLoader(
+    test_loader = DataLoader(
         g,
         target_idx[test_idx],
         test_sampler,
...
@@ -38,7 +38,7 @@ def load_data(data_name, get_norm=False, inv_target=False):
     category_id = hg.ntypes.index(category)
     g = dgl.to_homogeneous(hg, edata=edata)
-    # Rename the fields as they can be changed by for example NodeDataLoader
+    # Rename the fields as they can be changed by for example DataLoader
     g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
     g.ndata['type_id'] = g.ndata.pop(dgl.NID)
     node_ids = th.arange(g.num_nodes())
...
 import torch
 import dgl
-from dgl.dataloading.dataloader import EdgeCollator
-from dgl.dataloading import BlockSampler
-from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages
+from dgl._dataloading.dataloader import EdgeCollator
+from dgl._dataloading import BlockSampler
+from dgl._dataloading.pytorch import _pop_subgraph_storage, _pop_storages
 from dgl.base import DGLError
 from functools import partial
...
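
Aside on the last hunk: the collator and sampler internals it imports (EdgeCollator, BlockSampler, the _pop_* storage helpers) now live under the private dgl._dataloading namespace, with dgl.dataloading.DataLoader as the public entry point. A hedged, version-tolerant import sketch, assuming only the module paths shown in the diff above:

# Sketch: prefer the private location used by this commit, fall back to the
# older public path on DGL releases that still provide it.
try:
    from dgl._dataloading.dataloader import EdgeCollator
    from dgl._dataloading import BlockSampler
except ImportError:
    from dgl.dataloading.dataloader import EdgeCollator
    from dgl.dataloading import BlockSampler
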