Unverified Commit 5da3439c authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Feature] Add Device Flag in Data Loaders (#2450)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent b1fb3c15
...@@ -139,36 +139,52 @@ class _EdgeCollator(EdgeCollator): ...@@ -139,36 +139,52 @@ class _EdgeCollator(EdgeCollator):
_pop_blocks_storage(result[-1], self.g_sampling) _pop_blocks_storage(result[-1], self.g_sampling)
return result return result
def _to_device(data, device):
    """Copy a sampled result to the target device.

    Parameters
    ----------
    data : dict, list, or object with a ``.to(device)`` method
        A single tensor/graph-like object, a list of them (e.g. blocks),
        or a dict mapping keys (e.g. node types) to them.
    device : device context
        The target device, e.g. a ``torch.device``.

    Returns
    -------
    Same container kind as ``data``
        A new dict/list (or the moved object itself) with every element
        transferred to ``device``.
    """
    # Build fresh containers rather than mutating the input: the previous
    # version updated dicts in place but rebuilt lists, which was an
    # inconsistent side-effect contract. Callers only use the return value,
    # so returning new containers is behavior-compatible.
    if isinstance(data, dict):
        return {key: value.to(device) for key, value in data.items()}
    if isinstance(data, list):
        return [item.to(device) for item in data]
    return data.to(device)
class _NodeDataLoaderIter: class _NodeDataLoaderIter:
def __init__(self, node_dataloader): def __init__(self, node_dataloader):
self.device = node_dataloader.device
self.node_dataloader = node_dataloader self.node_dataloader = node_dataloader
self.iter_ = iter(node_dataloader.dataloader) self.iter_ = iter(node_dataloader.dataloader)
def __next__(self): def __next__(self):
# input_nodes, output_nodes, [items], blocks # input_nodes, output_nodes, [items], blocks
result = next(self.iter_) result_ = next(self.iter_)
_restore_blocks_storage(result[-1], self.node_dataloader.collator.g) _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)
result = []
for data in result_:
result.append(_to_device(data, self.device))
return result return result
class _EdgeDataLoaderIter: class _EdgeDataLoaderIter:
def __init__(self, edge_dataloader): def __init__(self, edge_dataloader):
self.device = edge_dataloader.device
self.edge_dataloader = edge_dataloader self.edge_dataloader = edge_dataloader
self.iter_ = iter(edge_dataloader.dataloader) self.iter_ = iter(edge_dataloader.dataloader)
def __next__(self): def __next__(self):
if self.edge_dataloader.collator.negative_sampler is None: result_ = next(self.iter_)
# input_nodes, pair_graph, [items], blocks
result = next(self.iter_) if self.edge_dataloader.collator.negative_sampler is not None:
_restore_subgraph_storage(result[1], self.edge_dataloader.collator.g)
_restore_blocks_storage(result[-1], self.edge_dataloader.collator.g_sampling)
return result
else:
# input_nodes, pair_graph, neg_pair_graph, [items], blocks # input_nodes, pair_graph, neg_pair_graph, [items], blocks
result = next(self.iter_) # Otherwise, input_nodes, pair_graph, [items], blocks
_restore_subgraph_storage(result[1], self.edge_dataloader.collator.g) _restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
_restore_subgraph_storage(result[2], self.edge_dataloader.collator.g) _restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
_restore_blocks_storage(result[-1], self.edge_dataloader.collator.g_sampling) _restore_blocks_storage(result_[-1], self.edge_dataloader.collator.g_sampling)
return result
result = []
for data in result_:
result.append(_to_device(data, self.device))
return result
class NodeDataLoader: class NodeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
...@@ -182,6 +198,9 @@ class NodeDataLoader: ...@@ -182,6 +198,9 @@ class NodeDataLoader:
The node set to compute outputs. The node set to compute outputs.
block_sampler : dgl.dataloading.BlockSampler block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
device : device context, optional
The device of the generated blocks in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
kwargs : dict kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`. Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
...@@ -200,7 +219,7 @@ class NodeDataLoader: ...@@ -200,7 +219,7 @@ class NodeDataLoader:
""" """
collator_arglist = inspect.getfullargspec(NodeCollator).args collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, block_sampler, **kwargs): def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -210,6 +229,7 @@ class NodeDataLoader: ...@@ -210,6 +229,7 @@ class NodeDataLoader:
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
if isinstance(g, DistGraph): if isinstance(g, DistGraph):
assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
# Distributed DataLoader currently does not support heterogeneous graphs # Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution # and does not copy features. Fallback to normal solution
self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs) self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
...@@ -224,6 +244,7 @@ class NodeDataLoader: ...@@ -224,6 +244,7 @@ class NodeDataLoader:
collate_fn=self.collator.collate, collate_fn=self.collator.collate,
**dataloader_kwargs) **dataloader_kwargs)
self.is_distributed = False self.is_distributed = False
self.device = device
def __iter__(self): def __iter__(self):
"""Return the iterator of the data loader.""" """Return the iterator of the data loader."""
...@@ -267,6 +288,9 @@ class EdgeDataLoader: ...@@ -267,6 +288,9 @@ class EdgeDataLoader:
The edge set in graph :attr:`g` to compute outputs. The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
device : device context, optional
The device of the generated blocks and graphs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
g_sampling : DGLGraph, optional g_sampling : DGLGraph, optional
The graph where neighborhood sampling is performed. The graph where neighborhood sampling is performed.
...@@ -397,7 +421,7 @@ class EdgeDataLoader: ...@@ -397,7 +421,7 @@ class EdgeDataLoader:
""" """
collator_arglist = inspect.getfullargspec(EdgeCollator).args collator_arglist = inspect.getfullargspec(EdgeCollator).args
def __init__(self, g, eids, block_sampler, **kwargs): def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
...@@ -412,6 +436,7 @@ class EdgeDataLoader: ...@@ -412,6 +436,7 @@ class EdgeDataLoader:
+ 'Please use DistDataLoader directly.' + 'Please use DistDataLoader directly.'
self.dataloader = DataLoader( self.dataloader = DataLoader(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs) self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
self.device = device
def __iter__(self): def __iter__(self):
"""Return the iterator of the data loader.""" """Return the iterator of the data loader."""
......
import dgl import dgl
import backend as F import backend as F
import numpy as np
import unittest import unittest
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from collections import defaultdict from collections import defaultdict
...@@ -217,6 +216,127 @@ def test_graph_dataloader(): ...@@ -217,6 +216,127 @@ def test_graph_dataloader():
assert isinstance(graph, dgl.DGLGraph) assert isinstance(graph, dgl.DGLGraph)
assert F.asnumpy(label).shape[0] == batch_size assert F.asnumpy(label).shape[0] == batch_size
def _check_device(data):
    """Assert that every tensor/graph in ``data`` lives on the backend context.

    ``data`` may be a single object exposing ``.device``, a list of such
    objects, or a dict whose values are such objects.
    """
    expected = F.ctx()
    if isinstance(data, dict):
        for value in data.values():
            assert value.device == expected
    elif isinstance(data, list):
        for value in data:
            assert value.device == expected
    else:
        assert data.device == expected
def test_node_dataloader():
    """NodeDataLoader with ``device=`` must place every yielded item on that device."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

    # Homogeneous graph with CPU-resident features.
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # Without item indices: batches are (input_nodes, output_nodes, blocks).
    loader = dgl.dataloading.NodeDataLoader(
        g1, g1.nodes(), sampler, device=F.ctx(), batch_size=g1.num_nodes())
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # With item indices: batches are (input_nodes, output_nodes, items, blocks).
    loader = dgl.dataloading.NodeDataLoader(
        g1, g1.nodes(), sampler, device=F.ctx(), batch_size=g1.num_nodes(),
        return_indices=True)
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # Heterogeneous graph with CPU-resident features per node type.
    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    nids = {nty: g2.nodes(nty) for nty in g2.ntypes}

    # Without item indices.
    loader = dgl.dataloading.NodeDataLoader(
        g2, nids, sampler, device=F.ctx(), batch_size=batch_size)
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # With item indices.
    loader = dgl.dataloading.NodeDataLoader(
        g2, nids, sampler, device=F.ctx(), batch_size=batch_size,
        return_indices=True)
    for batch in loader:
        for entry in batch:
            _check_device(entry)
def test_edge_dataloader():
    """EdgeDataLoader with ``device=`` must place every yielded item on that device."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)

    # Homogeneous graph with CPU-resident features.
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # No indices, no negative sampling: (input_nodes, pair_graph, blocks).
    loader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(), batch_size=g1.num_edges())
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # Negative sampling adds a negative pair graph to every batch.
    loader = dgl.dataloading.EdgeDataLoader(
        g1, g1.edges(form='eid'), sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=g1.num_edges())
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # Heterogeneous graph with CPU-resident features per node type.
    g2 = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    eids = {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes}

    # With item indices, no negative sampling.
    loader = dgl.dataloading.EdgeDataLoader(
        g2, eids, sampler, device=F.ctx(), batch_size=batch_size,
        return_indices=True)
    for batch in loader:
        for entry in batch:
            _check_device(entry)

    # With item indices and negative sampling.
    loader = dgl.dataloading.EdgeDataLoader(
        g2, eids, sampler, device=F.ctx(), negative_sampler=neg_sampler,
        batch_size=batch_size, return_indices=True)
    for batch in loader:
        for entry in batch:
            _check_device(entry)
if __name__ == '__main__': if __name__ == '__main__':
test_neighbor_sampler_dataloader() test_neighbor_sampler_dataloader()
test_graph_dataloader() test_graph_dataloader()
test_node_dataloader()
test_edge_dataloader()
...@@ -9,7 +9,6 @@ from test_utils.graph_cases import get_cases, random_graph, random_bipartite, ra ...@@ -9,7 +9,6 @@ from test_utils.graph_cases import get_cases, random_graph, random_bipartite, ra
from test_utils import parametrize_dtype from test_utils import parametrize_dtype
from copy import deepcopy from copy import deepcopy
import numpy as np
import scipy as sp import scipy as sp
def _AXWb(A, X, W, b): def _AXWb(A, X, W, b):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment