Unverified Commit 74f01405 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Example] Rename NodeDataLoader to DataLoader in GraphSAGE example (#3972)



* rename

* Update node_classification.py

* more fixes...
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 22d7f924
......@@ -97,14 +97,14 @@ def main(args):
# train sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
loader = dgl.dataloading.NodeDataLoader(
loader = dgl.dataloading.DataLoader(
g, {category: train_idx}, sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
# validation sampler
# we do not use full neighbor to save computation resources
val_sampler = dgl.dataloading.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
val_loader = dgl.dataloading.NodeDataLoader(
val_loader = dgl.dataloading.DataLoader(
g, {category: val_idx}, val_sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
......
......@@ -350,7 +350,7 @@ class EntityClassify(nn.Module):
for k in g.ntypes}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
{k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
sampler,
......@@ -445,7 +445,7 @@ class EntityClassify_HeteroAPI(nn.Module):
for k in g.ntypes}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = dgl.dataloading.DataLoader(
g,
{k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
sampler,
......
......@@ -8,7 +8,7 @@ import torch as th
import torch.nn.functional as F
import dgl
from dgl.dataloading import MultiLayerNeighborSampler, NodeDataLoader
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from torchmetrics.functional import accuracy
from tqdm import tqdm
......@@ -19,7 +19,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
fanouts = [int(fanout) for fanout in args.fanout.split(',')]
sampler = MultiLayerNeighborSampler(fanouts)
train_loader = NodeDataLoader(
train_loader = DataLoader(
g,
target_idx[train_idx],
sampler,
......@@ -30,7 +30,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
drop_last=False)
# The datasets do not have a validation subset, use the train subset
val_loader = NodeDataLoader(
val_loader = DataLoader(
g,
target_idx[train_idx],
sampler,
......@@ -42,7 +42,7 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
# -1 for sampling all neighbors
test_sampler = MultiLayerNeighborSampler([-1] * len(fanouts))
test_loader = NodeDataLoader(
test_loader = DataLoader(
g,
target_idx[test_idx],
test_sampler,
......
......@@ -38,7 +38,7 @@ def load_data(data_name, get_norm=False, inv_target=False):
category_id = hg.ntypes.index(category)
g = dgl.to_homogeneous(hg, edata=edata)
# Rename the fields as they can be changed by for example NodeDataLoader
# Rename the fields as they can be changed by for example DataLoader
g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
g.ndata['type_id'] = g.ndata.pop(dgl.NID)
node_ids = th.arange(g.num_nodes())
......
import torch
import dgl
from dgl.dataloading.dataloader import EdgeCollator
from dgl.dataloading import BlockSampler
from dgl.dataloading.pytorch import _pop_subgraph_storage, _pop_storages
from dgl._dataloading.dataloader import EdgeCollator
from dgl._dataloading import BlockSampler
from dgl._dataloading.pytorch import _pop_subgraph_storage, _pop_storages
from dgl.base import DGLError
from functools import partial
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment