"...models/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "21790df9a4f77bd9ec4db44de04594cb539457a7"
Unverified Commit a9f8f258 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Doc] Update GraphDataLoader in our docs (#2504)



* update graphdataloader in docs

* fix

* update examples

* fix sagpool
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent e0189397
......@@ -26,7 +26,7 @@ Prediction* tasks.
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
def _collate_fn(batch):
......@@ -41,9 +41,9 @@ Prediction* tasks.
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
Loading *Node Property Prediction* datasets is similar, but note that
there is only one graph object in this kind of dataset.
......
......@@ -96,21 +96,14 @@ follows:
import dgl
import torch
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
# load data
dataset = QM7bDataset()
num_labels = dataset.num_labels
# create collate_fn
def _collate_fn(batch):
    """Collate a (graphs, labels) pair into one batched DGLGraph and a long tensor.

    ``batch`` is expected to be a 2-tuple: a list of DGLGraphs and the
    matching labels.
    """
    graph_list, label_list = batch
    batched_graph = dgl.batch(graph_list)
    label_tensor = torch.tensor(label_list, dtype=torch.long)
    return batched_graph, label_tensor
# create dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# training
for epoch in range(100):
......
......@@ -193,27 +193,15 @@ Assuming that one have a graph classification dataset as introduced in
Each item in the graph classification dataset is a pair of a graph and
its label. One can speed up the data loading process by taking advantage
of the DataLoader, by customizing the collate function to batch the
graphs:
.. code:: python
def collate(samples):
    """Merge a list of (graph, label) pairs into a batched graph and a label tensor."""
    graph_list, label_list = map(list, zip(*samples))
    return dgl.batch(graph_list), torch.tensor(label_list)
Then one can create a DataLoader that iterates over the dataset of
of the GraphDataLoader to iterate over the dataset of
graphs in mini-batches.
.. code:: python
from torch.utils.data import DataLoader
dataloader = DataLoader(
from dgl.dataloading import GraphDataLoader
dataloader = GraphDataLoader(
dataset,
batch_size=1024,
collate_fn=collate,
drop_last=False,
shuffle=True)
......
......@@ -24,7 +24,7 @@
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
def _collate_fn(batch):
# 小批次是一个元组(graph, label)列表
......@@ -38,9 +38,9 @@
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
加载 *Node Property Prediction* 数据集类似,但要注意的是这种数据集只有一个图对象。
......
......@@ -82,21 +82,14 @@ DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图,
import dgl
import torch
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels
# 创建collate_fn函数
def _collate_fn(batch):
    """Batch one sample tuple: merge the graphs and cast labels to long.

    ``batch`` is a (list-of-graphs, labels) pair as produced by the dataset.
    """
    gs, ys = batch
    merged = dgl.batch(gs)
    ys = torch.tensor(ys, dtype=torch.long)
    return merged, ys
# 创建 dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# 训练
for epoch in range(100):
......
......@@ -143,25 +143,14 @@ DGL内置了常见的图读出函数,例如 :func:`dgl.readout_nodes` 就实
dataset = dgl.data.GINDataset('MUTAG', False)
整图分类数据集里的每个数据点是一个图和它对应标签的元组。为提升数据加载速度,
用户可以在DataLoader里自定义collate函数
用户可以调用GraphDataLoader,从而以小批次遍历整个图数据集
.. code:: python
def collate(samples):
    """Turn a list of (graph, label) pairs into one batched graph plus labels."""
    gs, ys = map(list, zip(*samples))
    merged_graph = dgl.batch(gs)
    label_tensor = torch.tensor(ys)
    return merged_graph, label_tensor
随后用户可以创建一个以小批次遍历整个图数据集的DataLoader
.. code:: python
from torch.utils.data import DataLoader
dataloader = DataLoader(
from dgl.dataloading import GraphDataLoader
dataloader = GraphDataLoader(
dataset,
batch_size=1024,
collate_fn=collate,
drop_last=False,
shuffle=True)
......
......@@ -110,12 +110,11 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
pre_process(dataset, prog_args)
# dataset.set_fold(fold)
return torch.utils.data.DataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=True,
num_workers=prog_args.n_worker)
return dgl.dataloading.GraphDataLoader(dataset,
batch_size=prog_args.batch_size,
shuffle=shuffle,
drop_last=True,
num_workers=prog_args.n_worker)
def graph_classify_task(prog_args):
......@@ -191,26 +190,6 @@ def graph_classify_task(prog_args):
print("test accuracy {}%".format(result * 100))
def collate_fn(batch):
    """Collate function for dataset batching.

    Casts every node feature to float32 and merges the samples into a
    single batched DGLGraph plus a LongTensor of graph labels.
    """
    graphs, labels = map(list, zip(*batch))
    # Node features must be float for the model's linear layers.
    for graph in graphs:
        for name, feat in graph.ndata.items():
            graph.ndata[name] = feat.float()
    batched_graphs = dgl.batch(graphs)
    batched_labels = torch.LongTensor(np.array(labels))
    return batched_graphs, batched_labels
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
'''
training function
......@@ -233,6 +212,9 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
print("EPOCH ###### {} ######".format(epoch))
computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available():
batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda()
......@@ -283,6 +265,9 @@ def evaluate(dataloader, model, prog_args, logger=None):
correct_label = 0
with torch.no_grad():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available():
batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda()
......
......@@ -18,11 +18,7 @@ import argparse
from sklearn.metrics import f1_score
from gat import GAT
from dgl.data.ppi import PPIDataset
from torch.utils.data import DataLoader
def collate(graphs):
    """Merge a list of DGLGraphs into a single batched graph."""
    return dgl.batch(graphs)
from dgl.dataloading import GraphDataLoader
def evaluate(feats, model, subgraph, labels, loss_fcn):
with torch.no_grad():
......@@ -54,9 +50,9 @@ def main(args):
train_dataset = PPIDataset(mode='train')
valid_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size)
valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size)
test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size)
g = train_dataset[0]
n_classes = train_dataset.num_labels
num_feats = g.ndata['feat'].shape[1]
......
......@@ -6,32 +6,18 @@ PyTorch compatible dataloader
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
import dgl
from dgl.dataloading import GraphDataLoader
# default collate function
def collate(samples):
    """Default collate: batch (graph, label) pairs with float node features.

    ``samples`` is a list of (graph, label) pairs. Node features are cast
    to float; edge features are left untouched.
    """
    graphs, labels = map(list, zip(*samples))
    for graph in graphs:
        # Only node features are converted; these graphs carry no edge feats.
        for name in graph.node_attr_schemes().keys():
            graph.ndata[name] = graph.ndata[name].float()
    return dgl.batch(graphs), torch.tensor(labels)
class GraphDataLoader():
class GINDataLoader():
def __init__(self,
dataset,
batch_size,
device,
collate_fn=collate,
collate_fn=None,
seed=0,
shuffle=True,
split_name='fold10',
......@@ -56,10 +42,10 @@ class GraphDataLoader():
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
self.train_loader = DataLoader(
self.train_loader = GraphDataLoader(
dataset, sampler=train_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
self.valid_loader = DataLoader(
self.valid_loader = GraphDataLoader(
dataset, sampler=valid_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
......
......@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.optim as optim
from dgl.data.gindt import GINDataset
from dataloader import GraphDataLoader, collate
from dataloader import GINDataLoader
from parser import Parser
from gin import GIN
......@@ -22,6 +22,8 @@ def train(args, net, trainloader, optimizer, criterion, epoch):
for pos, (graphs, labels) in zip(bar, trainloader):
# batch graphs will be shipped to device in forward part of model
for key in graphs.node_attr_schemes().keys():
graphs.ndata[key] = graphs.ndata[key].float()
labels = labels.to(args.device)
feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device)
......@@ -53,6 +55,8 @@ def eval_net(args, net, dataloader, criterion):
for data in dataloader:
graphs, labels = data
for key in graphs.node_attr_schemes().keys():
graphs.ndata[key] = graphs.ndata[key].float()
feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device)
labels = labels.to(args.device)
......@@ -88,9 +92,9 @@ def main(args):
dataset = GINDataset(args.dataset, not args.learn_eps)
trainloader, validloader = GraphDataLoader(
trainloader, validloader = GINDataLoader(
dataset, batch_size=args.batch_size, device=args.device,
collate_fn=collate, seed=args.seed, shuffle=True,
seed=args.seed, shuffle=True,
split_name='fold10', fold_idx=args.fold_idx).train_valid_loader()
# or split_name='rand', split_ratio=0.7
......
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np
def collate_fn(batch):
    """Collate function for dataset batching.

    Converts every node feature tensor to float and merges the sampled
    (graph, label) pairs into one batched DGLGraph and a LongTensor.
    """
    graphs, labels = map(list, zip(*batch))
    # Cast node data to float so all samples share one dtype.
    for g in graphs:
        for name, feat in g.ndata.items():
            g.ndata[name] = feat.float()
    merged = dgl.batch(graphs)
    return merged, torch.LongTensor(np.array(labels))
class GraphDataLoader(DataLoader):
    """A thin ``DataLoader`` wrapper that always batches with ``collate_fn``."""

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        # Delegate everything to torch's DataLoader, pinning this module's
        # graph-aware collate function.
        super().__init__(dataset, batch_size, shuffle,
                         collate_fn=collate_fn, **kwargs)
......@@ -8,9 +8,9 @@ import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split
from dataloader import GraphDataLoader
from network import get_sag_network
from utils import get_stats
......@@ -78,11 +78,14 @@ def parse_args():
def train(model:torch.nn.Module, optimizer, trainloader, device):
model.train()
total_loss = 0.
num_batches = len(trainloader)
for batch in trainloader:
optimizer.zero_grad()
batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device)
batch_labels = batch_labels.long().to(device)
out = model(batch_graphs)
loss = F.nll_loss(out, batch_labels)
loss.backward()
......@@ -90,7 +93,7 @@ def train(model:torch.nn.Module, optimizer, trainloader, device):
total_loss += loss.item()
return total_loss / len(trainloader.dataset)
return total_loss / num_batches
@torch.no_grad()
......@@ -98,11 +101,13 @@ def test(model:torch.nn.Module, loader, device):
model.eval()
correct = 0.
loss = 0.
num_graphs = len(loader.dataset)
num_graphs = len(loader)
for batch in loader:
batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device)
batch_labels = batch_labels.long().to(device)
out = model(batch_graphs)
pred = out.argmax(dim=1)
loss += F.nll_loss(out, batch_labels, reduction="sum").item()
......
......@@ -427,7 +427,7 @@ class GraphDataLoader:
Parameters
----------
collate : Function, default is None
collate_fn : Function, default is None
The customized collate function. Will use the default collate
function if not given.
kwargs : dict
......@@ -445,7 +445,7 @@ class GraphDataLoader:
"""
collator_arglist = inspect.getfullargspec(GraphCollator).args
def __init__(self, dataset, collate=None, **kwargs):
def __init__(self, dataset, collate_fn=None, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -454,10 +454,10 @@ class GraphDataLoader:
else:
dataloader_kwargs[k] = v
if collate is None:
if collate_fn is None:
self.collate = GraphCollator(**collator_kwargs).collate
else:
self.collate = collate
self.collate = collate_fn
self.dataloader = DataLoader(dataset=dataset,
collate_fn=self.collate,
......
......@@ -36,6 +36,8 @@ networks to this problem has been a popular approach recently. This can be seen
# Implement a synthetic dataset :class:`data.MiniGCDataset` in DGL. The dataset has eight
# different types of graphs and each class has the same number of graph samples.
import dgl
import torch
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
......@@ -68,20 +70,6 @@ plt.show()
# :width: 400pt
# :align: center
#
# Define the following ``collate`` function to form a mini-batch from a given
# list of graph and label pairs.
import dgl
import torch
def collate(samples):
    """Form a mini-batch from a list of (graph, label) pairs."""
    graph_list, label_list = map(list, zip(*samples))
    batched = dgl.batch(graph_list)
    return batched, torch.tensor(label_list)
###############################################################################
# The return type of :func:`dgl.batch` is still a graph. In the same way,
# a batch of tensors is still a tensor. This means that any code that works
# for one graph immediately works for a batch of graphs. More importantly,
......@@ -149,15 +137,14 @@ class Classifier(nn.Module):
# :math:`80` graphs constitute a test set.
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.dataloading import GraphDataLoader
# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
collate_fn=collate)
# Use DGL's GraphDataLoader. It by default handles the
# graph batching operation for every mini-batch.
data_loader = GraphDataLoader(trainset, batch_size=32, shuffle=True)
# Create model
model = Classifier(1, 256, trainset.num_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment