Unverified Commit 5da3439c authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Feature] Add Device Flag in Data Loaders (#2450)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent b1fb3c15
......@@ -139,36 +139,52 @@ class _EdgeCollator(EdgeCollator):
_pop_blocks_storage(result[-1], self.g_sampling)
return result
def _to_device(data, device):
if isinstance(data, dict):
for k, v in data.items():
data[k] = v.to(device)
elif isinstance(data, list):
data = [item.to(device) for item in data]
else:
data = data.to(device)
return data
class _NodeDataLoaderIter:
    """Iterator over :class:`NodeDataLoader` batches.

    Restores the graph storages that were detached before the batch crossed
    worker-process boundaries, then moves every component of the batch to the
    data loader's configured device.
    """
    def __init__(self, node_dataloader):
        # Cache the target device so __next__ does not re-read it per batch.
        self.device = node_dataloader.device
        self.node_dataloader = node_dataloader
        self.iter_ = iter(node_dataloader.dataloader)

    def __next__(self):
        # Batch layout: input_nodes, output_nodes, [items], blocks.
        # Bug fix: the scraped version called next(self.iter_) twice per
        # iteration (a pre-change line left interleaved with the post-change
        # code), silently discarding every other batch.  Fetch exactly once.
        result_ = next(self.iter_)
        # Blocks are always the last element; reattach their feature storage.
        _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)
        # Transfer every component (tensors, dicts of tensors, blocks).
        return [_to_device(data, self.device) for data in result_]
class _EdgeDataLoaderIter:
    """Iterator over :class:`EdgeDataLoader` batches.

    Restores subgraph and block storages detached for inter-process transfer,
    then moves every component of the batch to the configured device.
    """
    def __init__(self, edge_dataloader):
        # Cache the target device so __next__ does not re-read it per batch.
        self.device = edge_dataloader.device
        self.edge_dataloader = edge_dataloader
        self.iter_ = iter(edge_dataloader.dataloader)

    def __next__(self):
        # Bug fix: the scraped version interleaved pre-change and post-change
        # diff lines — next(self.iter_) was called twice per iteration, an
        # unconditional early `return` skipped the device transfer, and the
        # storage-restore calls were duplicated.  This is the deduplicated
        # post-change logic: fetch once, restore, transfer.
        result_ = next(self.iter_)
        collator = self.edge_dataloader.collator
        if collator.negative_sampler is not None:
            # Layout: input_nodes, pair_graph, neg_pair_graph, [items], blocks
            # — the negative pair graph also needs its storage restored.
            _restore_subgraph_storage(result_[2], collator.g)
        # Common layout: input_nodes, pair_graph, [items], blocks.
        _restore_subgraph_storage(result_[1], collator.g)
        _restore_blocks_storage(result_[-1], collator.g_sampling)
        # Transfer every component (tensors, graphs, blocks).
        return [_to_device(data, self.device) for data in result_]
class NodeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list
......@@ -182,6 +198,9 @@ class NodeDataLoader:
The node set to compute outputs.
block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
device : device context, optional
The device of the generated blocks in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
......@@ -200,7 +219,7 @@ class NodeDataLoader:
"""
collator_arglist = inspect.getfullargspec(NodeCollator).args
def __init__(self, g, nids, block_sampler, **kwargs):
def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -210,6 +229,7 @@ class NodeDataLoader:
dataloader_kwargs[k] = v
if isinstance(g, DistGraph):
assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
......@@ -224,6 +244,7 @@ class NodeDataLoader:
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.is_distributed = False
self.device = device
def __iter__(self):
"""Return the iterator of the data loader."""
......@@ -267,6 +288,9 @@ class EdgeDataLoader:
The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
device : device context, optional
The device of the generated blocks and graphs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
g_sampling : DGLGraph, optional
The graph where neighborhood sampling is performed.
......@@ -397,7 +421,7 @@ class EdgeDataLoader:
"""
collator_arglist = inspect.getfullargspec(EdgeCollator).args
def __init__(self, g, eids, block_sampler, **kwargs):
def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
......@@ -412,6 +436,7 @@ class EdgeDataLoader:
+ 'Please use DistDataLoader directly.'
self.dataloader = DataLoader(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
self.device = device
def __iter__(self):
"""Return the iterator of the data loader."""
......
import dgl
import backend as F
import numpy as np
import unittest
from torch.utils.data import DataLoader
from collections import defaultdict
......@@ -217,6 +216,127 @@ def test_graph_dataloader():
assert isinstance(graph, dgl.DGLGraph)
assert F.asnumpy(label).shape[0] == batch_size
def _check_device(data):
    """Assert that every tensor/graph in *data* lives on the backend context.

    Accepts a single object, a list of objects, or a dict whose values are
    objects; each leaf must expose a ``.device`` attribute.
    """
    if isinstance(data, dict):
        leaves = data.values()
    elif isinstance(data, list):
        leaves = data
    else:
        leaves = [data]
    for leaf in leaves:
        assert leaf.device == F.ctx()
def test_node_dataloader():
    """NodeDataLoader with ``device=F.ctx()`` must place every yielded
    component (input nodes, output nodes, optional item indices, blocks)
    on that device, for homogeneous and heterogeneous graphs alike."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

    # --- homogeneous graph ---
    homo = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    homo.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # Without per-batch item indices.
    loader = dgl.dataloading.NodeDataLoader(
        homo, homo.nodes(), sampler, device=F.ctx(),
        batch_size=homo.num_nodes())
    for in_nodes, out_nodes, blocks in loader:
        _check_device(in_nodes)
        _check_device(out_nodes)
        _check_device(blocks)

    # With per-batch item indices.
    loader = dgl.dataloading.NodeDataLoader(
        homo, homo.nodes(), sampler, device=F.ctx(),
        batch_size=homo.num_nodes(), return_indices=True)
    for in_nodes, out_nodes, idx, blocks in loader:
        _check_device(in_nodes)
        _check_device(out_nodes)
        _check_device(idx)
        _check_device(blocks)

    # --- heterogeneous graph ---
    hetero = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in hetero.ntypes:
        feat = F.copy_to(F.randn((hetero.num_nodes(ntype), 8)), F.cpu())
        hetero.nodes[ntype].data['feat'] = feat
    bsize = max(hetero.num_nodes(nty) for nty in hetero.ntypes)
    all_nids = {nty: hetero.nodes(nty) for nty in hetero.ntypes}

    # Without per-batch item indices.
    loader = dgl.dataloading.NodeDataLoader(
        hetero, all_nids, sampler, device=F.ctx(), batch_size=bsize)
    for in_nodes, out_nodes, blocks in loader:
        _check_device(in_nodes)
        _check_device(out_nodes)
        _check_device(blocks)

    # With per-batch item indices.
    loader = dgl.dataloading.NodeDataLoader(
        hetero, all_nids, sampler, device=F.ctx(),
        batch_size=bsize, return_indices=True)
    for in_nodes, out_nodes, idx, blocks in loader:
        _check_device(in_nodes)
        _check_device(out_nodes)
        _check_device(idx)
        _check_device(blocks)
def test_edge_dataloader():
    """EdgeDataLoader with ``device=F.ctx()`` must place every yielded
    component (input nodes, positive/negative pair graphs, optional item
    indices, blocks) on that device, with and without negative sampling."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)

    # --- homogeneous graph ---
    homo = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    homo.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # No item indices, no negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        homo, homo.edges(form='eid'), sampler, device=F.ctx(),
        batch_size=homo.num_edges())
    for in_nodes, pos_graph, blocks in loader:
        _check_device(in_nodes)
        _check_device(pos_graph)
        _check_device(blocks)

    # No item indices, with negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        homo, homo.edges(form='eid'), sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=homo.num_edges())
    for in_nodes, pos_graph, neg_graph, blocks in loader:
        _check_device(in_nodes)
        _check_device(pos_graph)
        _check_device(neg_graph)
        _check_device(blocks)

    # --- heterogeneous graph ---
    hetero = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in hetero.ntypes:
        feat = F.copy_to(F.randn((hetero.num_nodes(ntype), 8)), F.cpu())
        hetero.nodes[ntype].data['feat'] = feat
    bsize = max(hetero.num_edges(ety) for ety in hetero.canonical_etypes)
    all_eids = {ety: hetero.edges(form='eid', etype=ety)
                for ety in hetero.canonical_etypes}

    # Item indices, no negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        hetero, all_eids, sampler, device=F.ctx(),
        batch_size=bsize, return_indices=True)
    for in_nodes, pos_graph, idx, blocks in loader:
        _check_device(in_nodes)
        _check_device(pos_graph)
        _check_device(idx)
        _check_device(blocks)

    # Item indices, with negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        hetero, all_eids, sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=bsize, return_indices=True)
    for in_nodes, pos_graph, neg_graph, idx, blocks in loader:
        _check_device(in_nodes)
        _check_device(pos_graph)
        _check_device(neg_graph)
        _check_device(idx)
        _check_device(blocks)
# Allow running this test module directly (outside the pytest runner).
if __name__ == '__main__':
    test_neighbor_sampler_dataloader()
    test_graph_dataloader()
    test_node_dataloader()
    test_edge_dataloader()
......@@ -9,7 +9,6 @@ from test_utils.graph_cases import get_cases, random_graph, random_bipartite, ra
from test_utils import parametrize_dtype
from copy import deepcopy
import numpy as np
import scipy as sp
def _AXWb(A, X, W, b):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment