Unverified Commit a9f8f258 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Doc] Update GraphDataLoader in our docs (#2504)



* update graphdataloader in docs

* fix

* update examples

* fix sagpool
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent e0189397
...@@ -26,7 +26,7 @@ Prediction* tasks. ...@@ -26,7 +26,7 @@ Prediction* tasks.
import dgl import dgl
import torch import torch
from ogb.graphproppred import DglGraphPropPredDataset from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
def _collate_fn(batch): def _collate_fn(batch):
...@@ -41,9 +41,9 @@ Prediction* tasks. ...@@ -41,9 +41,9 @@ Prediction* tasks.
dataset = DglGraphPropPredDataset(name='ogbg-molhiv') dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
# dataloader # dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn) train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
Loading *Node Property Prediction* datasets is similar, but note that Loading *Node Property Prediction* datasets is similar, but note that
there is only one graph object in this kind of dataset. there is only one graph object in this kind of dataset.
......
...@@ -96,21 +96,14 @@ follows: ...@@ -96,21 +96,14 @@ follows:
import dgl import dgl
import torch import torch
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
# load data # load data
dataset = QM7bDataset() dataset = QM7bDataset()
num_labels = dataset.num_labels num_labels = dataset.num_labels
# create collate_fn
def _collate_fn(batch):
    """Collate a list of (graph, label) pairs into one batched graph and a label tensor.

    Parameters
    ----------
    batch : list of (DGLGraph, label) tuples
        The samples drawn by the DataLoader for one mini-batch.

    Returns
    -------
    (DGLGraph, Tensor)
        The batched graph and the stacked labels.
    """
    # BUG FIX: `batch` is a *list* of (graph, label) pairs. The original
    # `graphs, labels = batch` raises ValueError for batch_size=1 and
    # mispairs the data for batch_size=2 -- unzip the pairs instead.
    graphs, labels = map(list, zip(*batch))
    g = dgl.batch(graphs)
    # NOTE(review): long dtype kept from the original; confirm the dataset's
    # labels are integral (QM7b targets are regression floats).
    labels = torch.tensor(labels, dtype=torch.long)
    return g, labels
# create dataloaders # create dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn) dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# training # training
for epoch in range(100): for epoch in range(100):
......
...@@ -193,27 +193,15 @@ Assuming that one have a graph classification dataset as introduced in ...@@ -193,27 +193,15 @@ Assuming that one have a graph classification dataset as introduced in
Each item in the graph classification dataset is a pair of a graph and Each item in the graph classification dataset is a pair of a graph and
its label. One can speed up the data loading process by taking advantage its label. One can speed up the data loading process by taking advantage
of the DataLoader, by customizing the collate function to batch the of the GraphDataLoader to iterate over the dataset of
graphs:
.. code:: python
def collate(samples):
    """Merge a list of (graph, label) samples into a batched graph plus a label tensor."""
    # Unzip the sample pairs into parallel sequences.
    graph_list, label_list = zip(*samples)
    return dgl.batch(list(graph_list)), torch.tensor(list(label_list))
Then one can create a DataLoader that iterates over the dataset of
graphs in mini-batches. graphs in mini-batches.
.. code:: python .. code:: python
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
dataloader = DataLoader( dataloader = GraphDataLoader(
dataset, dataset,
batch_size=1024, batch_size=1024,
collate_fn=collate,
drop_last=False, drop_last=False,
shuffle=True) shuffle=True)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
import dgl import dgl
import torch import torch
from ogb.graphproppred import DglGraphPropPredDataset from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
def _collate_fn(batch): def _collate_fn(batch):
# 小批次是一个元组(graph, label)列表 # 小批次是一个元组(graph, label)列表
...@@ -38,9 +38,9 @@ ...@@ -38,9 +38,9 @@
dataset = DglGraphPropPredDataset(name='ogbg-molhiv') dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
# dataloader # dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn) train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
加载 *Node Property Prediction* 数据集类似,但要注意的是这种数据集只有一个图对象。 加载 *Node Property Prediction* 数据集类似,但要注意的是这种数据集只有一个图对象。
......
...@@ -82,21 +82,14 @@ DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图, ...@@ -82,21 +82,14 @@ DGL建议让 ``__getitem__(idx)`` 返回如上面代码所示的元组 ``(图,
import dgl import dgl
import torch import torch
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
# 数据导入 # 数据导入
dataset = QM7bDataset() dataset = QM7bDataset()
num_labels = dataset.num_labels num_labels = dataset.num_labels
# 创建collate_fn函数
def _collate_fn(batch):
    """Collate a list of (graph, label) pairs into one batched graph and a label tensor.

    Parameters
    ----------
    batch : list of (DGLGraph, label) tuples
        The samples drawn by the DataLoader for one mini-batch.

    Returns
    -------
    (DGLGraph, Tensor)
        The batched graph and the stacked labels.
    """
    # BUG FIX: `batch` is a *list* of (graph, label) pairs. The original
    # `graphs, labels = batch` raises ValueError for batch_size=1 and
    # mispairs the data for batch_size=2 -- unzip the pairs instead.
    graphs, labels = map(list, zip(*batch))
    g = dgl.batch(graphs)
    # NOTE(review): long dtype kept from the original; confirm the dataset's
    # labels are integral (QM7b targets are regression floats).
    labels = torch.tensor(labels, dtype=torch.long)
    return g, labels
# 创建 dataloaders # 创建 dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn) dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# 训练 # 训练
for epoch in range(100): for epoch in range(100):
......
...@@ -143,25 +143,14 @@ DGL内置了常见的图读出函数,例如 :func:`dgl.readout_nodes` 就实 ...@@ -143,25 +143,14 @@ DGL内置了常见的图读出函数,例如 :func:`dgl.readout_nodes` 就实
dataset = dgl.data.GINDataset('MUTAG', False) dataset = dgl.data.GINDataset('MUTAG', False)
整图分类数据集里的每个数据点是一个图和它对应标签的元组。为提升数据加载速度, 整图分类数据集里的每个数据点是一个图和它对应标签的元组。为提升数据加载速度,
用户可以DataLoader里自定义collate函数 用户可以调用GraphDataLoader,从而以小批次遍历整个图数据集
.. code:: python .. code:: python
def collate(samples): from dgl.dataloading import GraphDataLoader
graphs, labels = map(list, zip(*samples)) dataloader = GraphDataLoader(
batched_graph = dgl.batch(graphs)
batched_labels = torch.tensor(labels)
return batched_graph, batched_labels
随后用户可以创建一个以小批次遍历整个图数据集的DataLoader
.. code:: python
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset, dataset,
batch_size=1024, batch_size=1024,
collate_fn=collate,
drop_last=False, drop_last=False,
shuffle=True) shuffle=True)
......
...@@ -110,12 +110,11 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None): ...@@ -110,12 +110,11 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
pre_process(dataset, prog_args) pre_process(dataset, prog_args)
# dataset.set_fold(fold) # dataset.set_fold(fold)
return torch.utils.data.DataLoader(dataset, return dgl.dataloading.GraphDataLoader(dataset,
batch_size=prog_args.batch_size, batch_size=prog_args.batch_size,
shuffle=shuffle, shuffle=shuffle,
collate_fn=collate_fn, drop_last=True,
drop_last=True, num_workers=prog_args.n_worker)
num_workers=prog_args.n_worker)
def graph_classify_task(prog_args): def graph_classify_task(prog_args):
...@@ -191,26 +190,6 @@ def graph_classify_task(prog_args): ...@@ -191,26 +190,6 @@ def graph_classify_task(prog_args):
print("test accuracy {}%".format(result * 100)) print("test accuracy {}%".format(result * 100))
def collate_fn(batch):
    """Collate a mini-batch of (graph, label) samples.

    Casts every node-feature tensor to float, merges the graphs with
    ``dgl.batch`` and stacks the labels into a ``LongTensor``.
    """
    graph_list, label_list = zip(*batch)
    # Node features must be float tensors for the model's linear layers.
    for g in graph_list:
        for name, feat in list(g.ndata.items()):
            g.ndata[name] = feat.float()
    batched_graphs = dgl.batch(list(graph_list))
    batched_labels = torch.LongTensor(np.array(label_list))
    return batched_graphs, batched_labels
def train(dataset, model, prog_args, same_feat=True, val_dataset=None): def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
''' '''
training function training function
...@@ -233,6 +212,9 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None): ...@@ -233,6 +212,9 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
print("EPOCH ###### {} ######".format(epoch)) print("EPOCH ###### {} ######".format(epoch))
computation_time = 0.0 computation_time = 0.0
for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader): for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available(): if torch.cuda.is_available():
batch_graph = batch_graph.to(torch.cuda.current_device()) batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda() graph_labels = graph_labels.cuda()
...@@ -283,6 +265,9 @@ def evaluate(dataloader, model, prog_args, logger=None): ...@@ -283,6 +265,9 @@ def evaluate(dataloader, model, prog_args, logger=None):
correct_label = 0 correct_label = 0
with torch.no_grad(): with torch.no_grad():
for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader): for batch_idx, (batch_graph, graph_labels) in enumerate(dataloader):
for (key, value) in batch_graph.ndata.items():
batch_graph.ndata[key] = value.float()
graph_labels = graph_labels.long()
if torch.cuda.is_available(): if torch.cuda.is_available():
batch_graph = batch_graph.to(torch.cuda.current_device()) batch_graph = batch_graph.to(torch.cuda.current_device())
graph_labels = graph_labels.cuda() graph_labels = graph_labels.cuda()
......
...@@ -18,11 +18,7 @@ import argparse ...@@ -18,11 +18,7 @@ import argparse
from sklearn.metrics import f1_score from sklearn.metrics import f1_score
from gat import GAT from gat import GAT
from dgl.data.ppi import PPIDataset from dgl.data.ppi import PPIDataset
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
def collate(graphs):
    """Batch a list of DGLGraphs into a single graph."""
    return dgl.batch(graphs)
def evaluate(feats, model, subgraph, labels, loss_fcn): def evaluate(feats, model, subgraph, labels, loss_fcn):
with torch.no_grad(): with torch.no_grad():
...@@ -54,9 +50,9 @@ def main(args): ...@@ -54,9 +50,9 @@ def main(args):
train_dataset = PPIDataset(mode='train') train_dataset = PPIDataset(mode='train')
valid_dataset = PPIDataset(mode='valid') valid_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test') test_dataset = PPIDataset(mode='test')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate) train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate) valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate) test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size)
g = train_dataset[0] g = train_dataset[0]
n_classes = train_dataset.num_labels n_classes = train_dataset.num_labels
num_feats = g.ndata['feat'].shape[1] num_feats = g.ndata['feat'].shape[1]
......
...@@ -6,32 +6,18 @@ PyTorch compatible dataloader ...@@ -6,32 +6,18 @@ PyTorch compatible dataloader
import math import math
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import StratifiedKFold
import dgl import dgl
from dgl.dataloading import GraphDataLoader
# default collate function class GINDataLoader():
def collate(samples):
    """Default collate: batch (graph, label) pairs, casting node feats to float.

    Edge features are intentionally left untouched.
    """
    graph_list, label_list = map(list, zip(*samples))
    for g in graph_list:
        # Downstream layers expect float node features.
        for key in g.node_attr_schemes().keys():
            g.ndata[key] = g.ndata[key].float()
    return dgl.batch(graph_list), torch.tensor(label_list)
class GraphDataLoader():
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
device, device,
collate_fn=collate, collate_fn=None,
seed=0, seed=0,
shuffle=True, shuffle=True,
split_name='fold10', split_name='fold10',
...@@ -56,10 +42,10 @@ class GraphDataLoader(): ...@@ -56,10 +42,10 @@ class GraphDataLoader():
train_sampler = SubsetRandomSampler(train_idx) train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx) valid_sampler = SubsetRandomSampler(valid_idx)
self.train_loader = DataLoader( self.train_loader = GraphDataLoader(
dataset, sampler=train_sampler, dataset, sampler=train_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs) batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
self.valid_loader = DataLoader( self.valid_loader = GraphDataLoader(
dataset, sampler=valid_sampler, dataset, sampler=valid_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs) batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
......
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from dgl.data.gindt import GINDataset from dgl.data.gindt import GINDataset
from dataloader import GraphDataLoader, collate from dataloader import GINDataLoader
from parser import Parser from parser import Parser
from gin import GIN from gin import GIN
...@@ -22,6 +22,8 @@ def train(args, net, trainloader, optimizer, criterion, epoch): ...@@ -22,6 +22,8 @@ def train(args, net, trainloader, optimizer, criterion, epoch):
for pos, (graphs, labels) in zip(bar, trainloader): for pos, (graphs, labels) in zip(bar, trainloader):
# batch graphs will be shipped to device in forward part of model # batch graphs will be shipped to device in forward part of model
for key in graphs.node_attr_schemes().keys():
graphs.ndata[key] = graphs.ndata[key].float()
labels = labels.to(args.device) labels = labels.to(args.device)
feat = graphs.ndata.pop('attr').to(args.device) feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device) graphs = graphs.to(args.device)
...@@ -53,6 +55,8 @@ def eval_net(args, net, dataloader, criterion): ...@@ -53,6 +55,8 @@ def eval_net(args, net, dataloader, criterion):
for data in dataloader: for data in dataloader:
graphs, labels = data graphs, labels = data
for key in graphs.node_attr_schemes().keys():
graphs.ndata[key] = graphs.ndata[key].float()
feat = graphs.ndata.pop('attr').to(args.device) feat = graphs.ndata.pop('attr').to(args.device)
graphs = graphs.to(args.device) graphs = graphs.to(args.device)
labels = labels.to(args.device) labels = labels.to(args.device)
...@@ -88,9 +92,9 @@ def main(args): ...@@ -88,9 +92,9 @@ def main(args):
dataset = GINDataset(args.dataset, not args.learn_eps) dataset = GINDataset(args.dataset, not args.learn_eps)
trainloader, validloader = GraphDataLoader( trainloader, validloader = GINDataLoader(
dataset, batch_size=args.batch_size, device=args.device, dataset, batch_size=args.batch_size, device=args.device,
collate_fn=collate, seed=args.seed, shuffle=True, seed=args.seed, shuffle=True,
split_name='fold10', fold_idx=args.fold_idx).train_valid_loader() split_name='fold10', fold_idx=args.fold_idx).train_valid_loader()
# or split_name='rand', split_ratio=0.7 # or split_name='rand', split_ratio=0.7
......
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
import dgl
import numpy as np
def collate_fn(batch):
    """
    collate_fn for dataset batching.

    Unzips (graph, label) pairs, casts node data to float tensors,
    batches the graphs and stacks the labels into a LongTensor.
    """
    graph_list, label_list = map(list, zip(*batch))
    # Cast every node feature to float so downstream layers can consume it.
    for g in graph_list:
        for name, feat in list(g.ndata.items()):
            g.ndata[name] = feat.float()
    batched = dgl.batch(graph_list)
    return batched, torch.LongTensor(np.array(label_list))
class GraphDataLoader(DataLoader):
    """A PyTorch ``DataLoader`` preconfigured with the graph ``collate_fn``."""

    def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
        # Positional args mirror torch's DataLoader(dataset, batch_size, shuffle);
        # the module-level collate_fn handles graph batching.
        super().__init__(dataset, batch_size, shuffle,
                         collate_fn=collate_fn, **kwargs)
...@@ -8,9 +8,9 @@ import torch ...@@ -8,9 +8,9 @@ import torch
import torch.nn import torch.nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.data import LegacyTUDataset from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split from torch.utils.data import random_split
from dataloader import GraphDataLoader
from network import get_sag_network from network import get_sag_network
from utils import get_stats from utils import get_stats
...@@ -78,11 +78,14 @@ def parse_args(): ...@@ -78,11 +78,14 @@ def parse_args():
def train(model:torch.nn.Module, optimizer, trainloader, device): def train(model:torch.nn.Module, optimizer, trainloader, device):
model.train() model.train()
total_loss = 0. total_loss = 0.
num_batches = len(trainloader)
for batch in trainloader: for batch in trainloader:
optimizer.zero_grad() optimizer.zero_grad()
batch_graphs, batch_labels = batch batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device) batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device) batch_labels = batch_labels.long().to(device)
out = model(batch_graphs) out = model(batch_graphs)
loss = F.nll_loss(out, batch_labels) loss = F.nll_loss(out, batch_labels)
loss.backward() loss.backward()
...@@ -90,7 +93,7 @@ def train(model:torch.nn.Module, optimizer, trainloader, device): ...@@ -90,7 +93,7 @@ def train(model:torch.nn.Module, optimizer, trainloader, device):
total_loss += loss.item() total_loss += loss.item()
return total_loss / len(trainloader.dataset) return total_loss / num_batches
@torch.no_grad() @torch.no_grad()
...@@ -98,11 +101,13 @@ def test(model:torch.nn.Module, loader, device): ...@@ -98,11 +101,13 @@ def test(model:torch.nn.Module, loader, device):
model.eval() model.eval()
correct = 0. correct = 0.
loss = 0. loss = 0.
num_graphs = len(loader.dataset) num_graphs = len(loader)
for batch in loader: for batch in loader:
batch_graphs, batch_labels = batch batch_graphs, batch_labels = batch
for (key, value) in batch_graphs.ndata.items():
batch_graphs.ndata[key] = value.float()
batch_graphs = batch_graphs.to(device) batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.to(device) batch_labels = batch_labels.long().to(device)
out = model(batch_graphs) out = model(batch_graphs)
pred = out.argmax(dim=1) pred = out.argmax(dim=1)
loss += F.nll_loss(out, batch_labels, reduction="sum").item() loss += F.nll_loss(out, batch_labels, reduction="sum").item()
......
...@@ -427,7 +427,7 @@ class GraphDataLoader: ...@@ -427,7 +427,7 @@ class GraphDataLoader:
Parameters Parameters
---------- ----------
collate : Function, default is None collate_fn : Function, default is None
The customized collate function. Will use the default collate The customized collate function. Will use the default collate
function if not given. function if not given.
kwargs : dict kwargs : dict
...@@ -445,7 +445,7 @@ class GraphDataLoader: ...@@ -445,7 +445,7 @@ class GraphDataLoader:
""" """
collator_arglist = inspect.getfullargspec(GraphCollator).args collator_arglist = inspect.getfullargspec(GraphCollator).args
def __init__(self, dataset, collate=None, **kwargs): def __init__(self, dataset, collate_fn=None, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -454,10 +454,10 @@ class GraphDataLoader: ...@@ -454,10 +454,10 @@ class GraphDataLoader:
else: else:
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
if collate is None: if collate_fn is None:
self.collate = GraphCollator(**collator_kwargs).collate self.collate = GraphCollator(**collator_kwargs).collate
else: else:
self.collate = collate self.collate = collate_fn
self.dataloader = DataLoader(dataset=dataset, self.dataloader = DataLoader(dataset=dataset,
collate_fn=self.collate, collate_fn=self.collate,
......
...@@ -36,6 +36,8 @@ networks to this problem has been a popular approach recently. This can be seen ...@@ -36,6 +36,8 @@ networks to this problem has been a popular approach recently. This can be seen
# Implement a synthetic dataset :class:`data.MiniGCDataset` in DGL. The dataset has eight # Implement a synthetic dataset :class:`data.MiniGCDataset` in DGL. The dataset has eight
# different types of graphs and each class has the same number of graph samples. # different types of graphs and each class has the same number of graph samples.
import dgl
import torch
from dgl.data import MiniGCDataset from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
...@@ -68,20 +70,6 @@ plt.show() ...@@ -68,20 +70,6 @@ plt.show()
# :width: 400pt # :width: 400pt
# :align: center # :align: center
# #
# Define the following ``collate`` function to form a mini-batch from a given
# list of graph and label pairs.
import dgl
import torch
def collate(samples):
    """Form a mini-batch from a list of (graph, label) pairs."""
    graph_list, label_list = zip(*samples)
    batched = dgl.batch(list(graph_list))
    return batched, torch.tensor(list(label_list))
###############################################################################
# The return type of :func:`dgl.batch` is still a graph. In the same way, # The return type of :func:`dgl.batch` is still a graph. In the same way,
# a batch of tensors is still a tensor. This means that any code that works # a batch of tensors is still a tensor. This means that any code that works
# for one graph immediately works for a batch of graphs. More importantly, # for one graph immediately works for a batch of graphs. More importantly,
...@@ -149,15 +137,14 @@ class Classifier(nn.Module): ...@@ -149,15 +137,14 @@ class Classifier(nn.Module):
# :math:`80` graphs constitute a test set. # :math:`80` graphs constitute a test set.
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from dgl.dataloading import GraphDataLoader
# Create training and test sets. # Create training and test sets.
trainset = MiniGCDataset(320, 10, 20) trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20) testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function # Use DGL's GraphDataLoader. It by default handles the
# defined before. # graph batching operation for every mini-batch.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True, data_loader = GraphDataLoader(trainset, batch_size=32, shuffle=True)
collate_fn=collate)
# Create model # Create model
model = Classifier(1, 256, trainset.num_classes) model = Classifier(1, 256, trainset.num_classes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment