Unverified Commit bdb88df5 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Performance] Allow DataLoaders to return indices of the seed array (#2495)



* return indices from dataloader

* fixes

* fix

* fix distgraph and added some todos

* Update dataloader.py

* Update dataloader.py
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 79940695
...@@ -287,6 +287,25 @@ class Collator(ABC): ...@@ -287,6 +287,25 @@ class Collator(ABC):
""" """
raise NotImplementedError raise NotImplementedError
# TODO(BarclayII): DistGraph.idtype and DistGraph.device are in the code, however
# the underlying DGLGraph object could be None. I was unable to figure out how
# to properly implement those two properties so I'm working around that. If the
# graph is a DistGraph, I assume that the dtype and device of the data should
# be the same as the graph already.
#
# After idtype and device get properly implemented, we should remove these two
# _prepare_* functions.
def _prepare_tensor_dict(g, data, name, is_distributed):
if is_distributed:
x = F.tensor(next(iter(data.values())))
return {k: F.copy_to(F.astype(v, F.dtype(x)), F.context(x)) for k, v in data.items()}
else:
return utils.prepare_tensor_dict(g, data, name)
def _prepare_tensor(g, data, name, is_distributed):
return F.tensor(data) if is_distributed else utils.prepare_tensor(g, data, name)
class NodeCollator(Collator): class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for """DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling. training node classification or regression on a single graph with neighborhood sampling.
...@@ -299,6 +318,9 @@ class NodeCollator(Collator): ...@@ -299,6 +318,9 @@ class NodeCollator(Collator):
The node set to compute outputs. The node set to compute outputs.
block_sampler : dgl.dataloading.BlockSampler block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler. The neighborhood sampler.
return_indices : bool, default False
Whether to additionally return the indices of the input ``nids`` array sampled in the
minibatch.
Examples Examples
-------- --------
...@@ -314,19 +336,24 @@ class NodeCollator(Collator): ...@@ -314,19 +336,24 @@ class NodeCollator(Collator):
>>> for input_nodes, output_nodes, blocks in dataloader: >>> for input_nodes, output_nodes, blocks in dataloader:
... train_on(input_nodes, output_nodes, blocks) ... train_on(input_nodes, output_nodes, blocks)
""" """
def __init__(self, g, nids, block_sampler): def __init__(self, g, nids, block_sampler, return_indices=False):
self.g = g self.g = g
self._is_distributed = isinstance(g, DistGraph) self._is_distributed = isinstance(g, DistGraph)
if not isinstance(nids, Mapping): if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \ assert len(g.ntypes) == 1, \
"nids should be a dict of node type and ids for graph with multiple node types" "nids should be a dict of node type and ids for graph with multiple node types"
self.nids = nids
self.block_sampler = block_sampler self.block_sampler = block_sampler
self.return_indices = return_indices
if isinstance(nids, Mapping): if isinstance(nids, Mapping):
self._dataset = utils.FlattenedDict(nids) self.nids = _prepare_tensor_dict(g, nids, 'nids', self._is_distributed)
dataset = {k: F.arange(0, len(v), F.dtype(v), F.context(v))
for k, v in self.nids.items()} if return_indices else self.nids
self._dataset = utils.FlattenedDict(dataset)
else: else:
self._dataset = nids self.nids = _prepare_tensor(g, nids, 'nids', self._is_distributed)
self._dataset = F.arange(0, len(nids), F.dtype(nids), F.context(nids)) \
if return_indices else nids
@property @property
def dataset(self): def dataset(self):
...@@ -342,6 +369,9 @@ class NodeCollator(Collator): ...@@ -342,6 +369,9 @@ class NodeCollator(Collator):
Either a list of node IDs (for homogeneous graphs), or a list of node type-ID Either a list of node IDs (for homogeneous graphs), or a list of node type-ID
pairs (for heterogeneous graphs). pairs (for heterogeneous graphs).
If ``return_indices`` is True, represents the indices to the seed node
array(s) instead.
Returns Returns
------- -------
input_nodes : Tensor or dict[ntype, Tensor] input_nodes : Tensor or dict[ntype, Tensor]
...@@ -354,26 +384,34 @@ class NodeCollator(Collator): ...@@ -354,26 +384,34 @@ class NodeCollator(Collator):
If the original graph has multiple node types, return a dictionary of If the original graph has multiple node types, return a dictionary of
node type names and node ID tensors. Otherwise, return a single tensor. node type names and node ID tensors. Otherwise, return a single tensor.
indices : Tensor or dict[ntype, Tensor], optional
The indices of the sampled nodes in the ``nids`` member.
Only returned if ``return_indices`` is True.
blocks : list[DGLGraph] blocks : list[DGLGraph]
The list of blocks necessary for computing the representation. The list of blocks necessary for computing the representation.
""" """
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = _prepare_tensor_dict(self.g, items, 'items', self._is_distributed)
else:
items = _prepare_tensor(self.g, items, 'items', self._is_distributed)
# TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
# this function does not work. I'm again skipping this step as a workaround.
# We need to fix this.
if not self._is_distributed:
if isinstance(items, dict): if isinstance(items, dict):
items = utils.prepare_tensor_dict(self.g, items, 'items') sample_items = {k: F.gather_row(self.nids[k], v) for k, v in items.items()} \
if self.return_indices else items
else: else:
items = utils.prepare_tensor(self.g, items, 'items') sample_items = F.gather_row(self.nids, items) if self.return_indices else items
blocks = self.block_sampler.sample_blocks(self.g, items)
blocks = self.block_sampler.sample_blocks(self.g, sample_items)
output_nodes = blocks[-1].dstdata[NID] output_nodes = blocks[-1].dstdata[NID]
input_nodes = blocks[0].srcdata[NID] input_nodes = blocks[0].srcdata[NID]
if not self.return_indices:
return input_nodes, output_nodes, blocks return input_nodes, output_nodes, blocks
else:
return input_nodes, output_nodes, items, blocks
class EdgeCollator(Collator): class EdgeCollator(Collator):
"""DGL collator to combine edges and their computation dependencies within a minibatch for """DGL collator to combine edges and their computation dependencies within a minibatch for
...@@ -453,6 +491,9 @@ class EdgeCollator(Collator): ...@@ -453,6 +491,9 @@ class EdgeCollator(Collator):
A set of builtin negative samplers are provided in A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`. :ref:`the negative sampling module <api-dataloading-negative-sampling>`.
return_indices : bool, default False
Whether to additionally return the indices of the input ``eids`` array sampled in the
minibatch.
Examples Examples
-------- --------
...@@ -541,13 +582,15 @@ class EdgeCollator(Collator): ...@@ -541,13 +582,15 @@ class EdgeCollator(Collator):
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks) ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
""" """
def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None, def __init__(self, g, eids, block_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None): reverse_eids=None, reverse_etypes=None, negative_sampler=None,
return_indices=False):
self.g = g self.g = g
self._is_distributed = isinstance(g, DistGraph)
if not isinstance(eids, Mapping): if not isinstance(eids, Mapping):
assert len(g.etypes) == 1, \ assert len(g.etypes) == 1, \
"eids should be a dict of etype and ids for graph with multiple etypes" "eids should be a dict of etype and ids for graph with multiple etypes"
self.eids = eids
self.block_sampler = block_sampler self.block_sampler = block_sampler
self.return_indices = return_indices
# One may wish to iterate over the edges in one graph while perform sampling in # One may wish to iterate over the edges in one graph while perform sampling in
# another graph. This may be the case for iterating over validation and test # another graph. This may be the case for iterating over validation and test
...@@ -566,9 +609,14 @@ class EdgeCollator(Collator): ...@@ -566,9 +609,14 @@ class EdgeCollator(Collator):
self.negative_sampler = negative_sampler self.negative_sampler = negative_sampler
if isinstance(eids, Mapping): if isinstance(eids, Mapping):
self._dataset = utils.FlattenedDict(eids) self.eids = _prepare_tensor_dict(g, eids, 'eids', self._is_distributed)
dataset = {k: F.arange(0, len(v), F.dtype(v), F.context(v))
for k, v in self.eids.items()} if return_indices else self.eids
self._dataset = utils.FlattenedDict(dataset)
else: else:
self._dataset = eids self.eids = _prepare_tensor(g, eids, 'eids', self._is_distributed)
self._dataset = F.arange(0, len(eids), F.dtype(eids), F.context(eids)) \
if return_indices else eids
@property @property
def dataset(self): def dataset(self):
...@@ -578,17 +626,23 @@ class EdgeCollator(Collator): ...@@ -578,17 +626,23 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items') items = _prepare_tensor_dict(self.g_sampling, items, 'items', self._is_distributed)
else: else:
items = utils.prepare_tensor(self.g_sampling, items, 'items') items = _prepare_tensor(self.g_sampling, items, 'items', self._is_distributed)
pair_graph = self.g.edge_subgraph(items) if isinstance(items, dict):
sample_items = {k: F.gather_row(self.eids[k], v) for k, v in items.items()} \
if self.return_indices else items
else:
sample_items = F.gather_row(self.eids, items) if self.return_indices else items
pair_graph = self.g.edge_subgraph(sample_items)
seed_nodes = pair_graph.ndata[NID] seed_nodes = pair_graph.ndata[NID]
exclude_eids = _find_exclude_eids( exclude_eids = _find_exclude_eids(
self.g, self.g,
self.exclude, self.exclude,
items, sample_items,
reverse_eid_map=self.reverse_eids, reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes) reverse_etype_map=self.reverse_etypes)
...@@ -596,20 +650,29 @@ class EdgeCollator(Collator): ...@@ -596,20 +650,29 @@ class EdgeCollator(Collator):
self.g_sampling, seed_nodes, exclude_eids=exclude_eids) self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
input_nodes = blocks[0].srcdata[NID] input_nodes = blocks[0].srcdata[NID]
if not self.return_indices:
return input_nodes, pair_graph, blocks return input_nodes, pair_graph, blocks
else:
return input_nodes, pair_graph, items, blocks
def _collate_with_negative_sampling(self, items): def _collate_with_negative_sampling(self, items):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items') items = _prepare_tensor_dict(self.g_sampling, items, 'items', self._is_distributed)
else:
items = _prepare_tensor(self.g_sampling, items, 'items', self._is_distributed)
if isinstance(items, dict):
sample_items = {k: F.gather_row(self.eids[k], v) for k, v in items.items()} \
if self.return_indices else items
else: else:
items = utils.prepare_tensor(self.g_sampling, items, 'items') sample_items = F.gather_row(self.eids, items) if self.return_indices else items
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True) pair_graph = self.g.edge_subgraph(sample_items, preserve_nodes=True)
induced_edges = pair_graph.edata[EID] induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items) neg_srcdst = self.negative_sampler(self.g, sample_items)
if not isinstance(neg_srcdst, Mapping): if not isinstance(neg_srcdst, Mapping):
assert len(self.g.etypes) == 1, \ assert len(self.g.etypes) == 1, \
'graph has multiple or no edge types; '\ 'graph has multiple or no edge types; '\
...@@ -631,7 +694,7 @@ class EdgeCollator(Collator): ...@@ -631,7 +694,7 @@ class EdgeCollator(Collator):
exclude_eids = _find_exclude_eids( exclude_eids = _find_exclude_eids(
self.g, self.g,
self.exclude, self.exclude,
items, sample_items,
reverse_eid_map=self.reverse_eids, reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes) reverse_etype_map=self.reverse_etypes)
...@@ -639,7 +702,10 @@ class EdgeCollator(Collator): ...@@ -639,7 +702,10 @@ class EdgeCollator(Collator):
self.g_sampling, seed_nodes, exclude_eids=exclude_eids) self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
input_nodes = blocks[0].srcdata[NID] input_nodes = blocks[0].srcdata[NID]
if not self.return_indices:
return input_nodes, pair_graph, neg_pair_graph, blocks return input_nodes, pair_graph, neg_pair_graph, blocks
else:
return input_nodes, pair_graph, neg_pair_graph, items, blocks
def collate(self, items): def collate(self, items):
"""Combines the sampled edges into a minibatch for edge classification, edge """Combines the sampled edges into a minibatch for edge classification, edge
...@@ -651,6 +717,9 @@ class EdgeCollator(Collator): ...@@ -651,6 +717,9 @@ class EdgeCollator(Collator):
Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID Either a list of edge IDs (for homogeneous graphs), or a list of edge type-ID
pairs (for heterogeneous graphs). pairs (for heterogeneous graphs).
If ``return_indices`` is True, represents the indices to the seed edge
array(s) instead.
Returns Returns
------- -------
Either ``(input_nodes, pair_graph, blocks)``, or Either ``(input_nodes, pair_graph, blocks)``, or
...@@ -674,6 +743,10 @@ class EdgeCollator(Collator): ...@@ -674,6 +743,10 @@ class EdgeCollator(Collator):
Note that the metagraph of this graph will be identical to that of the original Note that the metagraph of this graph will be identical to that of the original
graph. graph.
items : Tensor or dict[ntype, Tensor]
The indices of the sampled edges in the ``eids`` member.
Only returned if ``return_indices`` is True.
blocks : list[DGLGraph] blocks : list[DGLGraph]
The list of blocks necessary for computing the representation of the edges. The list of blocks necessary for computing the representation of the edges.
""" """
......
...@@ -118,23 +118,26 @@ def _restore_blocks_storage(blocks, g): ...@@ -118,23 +118,26 @@ def _restore_blocks_storage(blocks, g):
class _NodeCollator(NodeCollator): class _NodeCollator(NodeCollator):
def collate(self, items): def collate(self, items):
input_nodes, output_nodes, blocks = super().collate(items) # input_nodes, output_nodes, [items], blocks
_pop_blocks_storage(blocks, self.g) result = super().collate(items)
return input_nodes, output_nodes, blocks _pop_blocks_storage(result[-1], self.g)
return result
class _EdgeCollator(EdgeCollator): class _EdgeCollator(EdgeCollator):
def collate(self, items): def collate(self, items):
if self.negative_sampler is None: if self.negative_sampler is None:
input_nodes, pair_graph, blocks = super().collate(items) # input_nodes, pair_graph, [items], blocks
_pop_subgraph_storage(pair_graph, self.g) result = super().collate(items)
_pop_blocks_storage(blocks, self.g_sampling) _pop_subgraph_storage(result[1], self.g)
return input_nodes, pair_graph, blocks _pop_blocks_storage(result[-1], self.g_sampling)
return result
else: else:
input_nodes, pair_graph, neg_pair_graph, blocks = super().collate(items) # input_nodes, pair_graph, neg_pair_graph, [items], blocks
_pop_subgraph_storage(pair_graph, self.g) result = super().collate(items)
_pop_subgraph_storage(neg_pair_graph, self.g) _pop_subgraph_storage(result[1], self.g)
_pop_blocks_storage(blocks, self.g_sampling) _pop_subgraph_storage(result[2], self.g)
return input_nodes, pair_graph, neg_pair_graph, blocks _pop_blocks_storage(result[-1], self.g_sampling)
return result
class _NodeDataLoaderIter: class _NodeDataLoaderIter:
def __init__(self, node_dataloader): def __init__(self, node_dataloader):
...@@ -142,9 +145,10 @@ class _NodeDataLoaderIter: ...@@ -142,9 +145,10 @@ class _NodeDataLoaderIter:
self.iter_ = iter(node_dataloader.dataloader) self.iter_ = iter(node_dataloader.dataloader)
def __next__(self): def __next__(self):
input_nodes, output_nodes, blocks = next(self.iter_) # input_nodes, output_nodes, [items], blocks
_restore_blocks_storage(blocks, self.node_dataloader.collator.g) result = next(self.iter_)
return input_nodes, output_nodes, blocks _restore_blocks_storage(result[-1], self.node_dataloader.collator.g)
return result
class _EdgeDataLoaderIter: class _EdgeDataLoaderIter:
def __init__(self, edge_dataloader): def __init__(self, edge_dataloader):
...@@ -153,16 +157,18 @@ class _EdgeDataLoaderIter: ...@@ -153,16 +157,18 @@ class _EdgeDataLoaderIter:
def __next__(self): def __next__(self):
if self.edge_dataloader.collator.negative_sampler is None: if self.edge_dataloader.collator.negative_sampler is None:
input_nodes, pair_graph, blocks = next(self.iter_) # input_nodes, pair_graph, [items], blocks
_restore_subgraph_storage(pair_graph, self.edge_dataloader.collator.g) result = next(self.iter_)
_restore_blocks_storage(blocks, self.edge_dataloader.collator.g_sampling) _restore_subgraph_storage(result[1], self.edge_dataloader.collator.g)
return input_nodes, pair_graph, blocks _restore_blocks_storage(result[-1], self.edge_dataloader.collator.g_sampling)
return result
else: else:
input_nodes, pair_graph, neg_pair_graph, blocks = next(self.iter_) # input_nodes, pair_graph, neg_pair_graph, [items], blocks
_restore_subgraph_storage(pair_graph, self.edge_dataloader.collator.g) result = next(self.iter_)
_restore_subgraph_storage(neg_pair_graph, self.edge_dataloader.collator.g) _restore_subgraph_storage(result[1], self.edge_dataloader.collator.g)
_restore_blocks_storage(blocks, self.edge_dataloader.collator.g_sampling) _restore_subgraph_storage(result[2], self.edge_dataloader.collator.g)
return input_nodes, pair_graph, neg_pair_graph, blocks _restore_blocks_storage(result[-1], self.edge_dataloader.collator.g_sampling)
return result
class NodeDataLoader: class NodeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
......
...@@ -343,7 +343,7 @@ class DistGraph: ...@@ -343,7 +343,7 @@ class DistGraph:
The example shows the creation of ``DistGraph`` in the standalone mode. The example shows the creation of ``DistGraph`` in the standalone mode.
>>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis', >>> dgl.distributed.partition_graph(g, 'graph_name', 1, num_hops=1, part_method='metis',
out_path='output/', reshuffle=True) ... out_path='output/', reshuffle=True)
>>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json') >>> g = dgl.distributed.DistGraph('graph_name', part_config='output/graph_name.json')
The example shows the creation of ``DistGraph`` in the distributed mode. The example shows the creation of ``DistGraph`` in the distributed mode.
...@@ -357,7 +357,7 @@ class DistGraph: ...@@ -357,7 +357,7 @@ class DistGraph:
... frontier = dgl.distributed.sample_neighbors(g, seeds, 10) ... frontier = dgl.distributed.sample_neighbors(g, seeds, 10)
... return dgl.to_block(frontier, seeds) ... return dgl.to_block(frontier, seeds)
>>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000, >>> dataloader = dgl.distributed.DistDataLoader(dataset=nodes, batch_size=1000,
collate_fn=sample, shuffle=True) ... collate_fn=sample, shuffle=True)
>>> for block in dataloader: >>> for block in dataloader:
... feat = g.ndata['features'][block.srcdata[dgl.NID]] ... feat = g.ndata['features'][block.srcdata[dgl.NID]]
... labels = g.ndata['labels'][block.dstdata[dgl.NID]] ... labels = g.ndata['labels'][block.dstdata[dgl.NID]]
...@@ -492,6 +492,7 @@ class DistGraph: ...@@ -492,6 +492,7 @@ class DistGraph:
long long
int int
""" """
# TODO(da?): describe when self._g is None and idtype shouldn't be called.
return self._g.idtype return self._g.idtype
@property @property
...@@ -513,6 +514,7 @@ class DistGraph: ...@@ -513,6 +514,7 @@ class DistGraph:
------- -------
Device context object Device context object
""" """
# TODO(da?): describe when self._g is None and device shouldn't be called.
return self._g.device return self._g.device
@property @property
...@@ -788,7 +790,7 @@ def _split_even(partition_book, rank, elements): ...@@ -788,7 +790,7 @@ def _split_even(partition_book, rank, elements):
# here we divide the element list as evenly as possible. If we use range partitioning, # here we divide the element list as evenly as possible. If we use range partitioning,
# the split results also respect the data locality. Range partitioning is the default # the split results also respect the data locality. Range partitioning is the default
# strategy. # strategy.
# TODO(zhegnda) we need another way to divide the list for other partitioning strategy. # TODO(zhengda) we need another way to divide the list for other partitioning strategy.
# compute the offset of each split and ensure that the difference of each partition size # compute the offset of each split and ensure that the difference of each partition size
# is 1. # is 1.
......
...@@ -6,21 +6,41 @@ from torch.utils.data import DataLoader ...@@ -6,21 +6,41 @@ from torch.utils.data import DataLoader
from collections import defaultdict from collections import defaultdict
from itertools import product from itertools import product
def _check_neighbor_sampling_dataloader(g, nids, dl, mode): def _check_neighbor_sampling_dataloader(g, nids, dl, mode, collator):
seeds = defaultdict(list) seeds = defaultdict(list)
for item in dl: for item in dl:
if mode == 'node': if mode == 'node':
input_nodes, output_nodes, blocks = item input_nodes, output_nodes, items, blocks = item
elif mode == 'edge': elif mode == 'edge':
input_nodes, pair_graph, blocks = item input_nodes, pair_graph, items, blocks = item
output_nodes = pair_graph.ndata[dgl.NID] output_nodes = pair_graph.ndata[dgl.NID]
elif mode == 'link': elif mode == 'link':
input_nodes, pair_graph, neg_graph, blocks = item input_nodes, pair_graph, neg_graph, items, blocks = item
output_nodes = pair_graph.ndata[dgl.NID] output_nodes = pair_graph.ndata[dgl.NID]
for ntype in pair_graph.ntypes: for ntype in pair_graph.ntypes:
assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID]) assert F.array_equal(pair_graph.nodes[ntype].data[dgl.NID], neg_graph.nodes[ntype].data[dgl.NID])
# TODO: check if items match output nodes/edges
if mode == 'node':
if len(g.ntypes) > 1:
for ntype in g.ntypes:
if ntype not in items:
assert len(output_nodes[ntype]) == 0
else:
assert F.array_equal(output_nodes[ntype], F.gather_row(collator.nids[ntype], items[ntype]))
else:
assert F.array_equal(output_nodes, F.gather_row(collator.nids, items))
else:
if len(g.etypes) > 1:
for etype, eids in collator.eids.items():
if etype not in items:
assert pair_graph.num_edges(etype=etype) == 0
else:
assert F.array_equal(pair_graph.edges[etype].data[dgl.EID], F.gather_row(eids, items[etype]))
else:
assert F.array_equal(pair_graph.edata[dgl.EID], F.gather_row(collator.eids, items))
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
for ntype in g.ntypes: for ntype in g.ntypes:
assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID]) assert F.array_equal(input_nodes[ntype], blocks[0].srcnodes[ntype].data[dgl.NID])
...@@ -28,6 +48,7 @@ def _check_neighbor_sampling_dataloader(g, nids, dl, mode): ...@@ -28,6 +48,7 @@ def _check_neighbor_sampling_dataloader(g, nids, dl, mode):
else: else:
assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID]) assert F.array_equal(input_nodes, blocks[0].srcdata[dgl.NID])
assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID]) assert F.array_equal(output_nodes, blocks[-1].dstdata[dgl.NID])
prev_dst = {ntype: None for ntype in g.ntypes} prev_dst = {ntype: None for ntype in g.ntypes}
for block in blocks: for block in blocks:
for canonical_etype in block.canonical_etypes: for canonical_etype in block.canonical_etypes:
...@@ -110,31 +131,34 @@ def test_neighbor_sampler_dataloader(): ...@@ -110,31 +131,34 @@ def test_neighbor_sampler_dataloader():
for seeds, sampler in product( for seeds, sampler in product(
[F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)], [F.tensor([0, 1, 2, 3, 5], dtype=F.int64), F.tensor([4, 5], dtype=F.int64)],
[g_sampler1, g_sampler2]): [g_sampler1, g_sampler2]):
collators.append(dgl.dataloading.NodeCollator(g, seeds, sampler)) collators.append(dgl.dataloading.NodeCollator(g, seeds, sampler, return_indices=True))
graphs.append(g) graphs.append(g)
nids.append({'user': seeds}) nids.append({'user': seeds})
modes.append('node') modes.append('node')
collators.append(dgl.dataloading.EdgeCollator(g, seeds, sampler)) collators.append(dgl.dataloading.EdgeCollator(g, seeds, sampler, return_indices=True))
graphs.append(g) graphs.append(g)
nids.append({'follow': seeds}) nids.append({'follow': seeds})
modes.append('edge') modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids)) g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids,
return_indices=True))
graphs.append(g) graphs.append(g)
nids.append({'follow': seeds}) nids.append({'follow': seeds})
modes.append('edge') modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) g, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2),
return_indices=True))
graphs.append(g) graphs.append(g)
nids.append({'follow': seeds}) nids.append({'follow': seeds})
modes.append('link') modes.append('link')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids, g, seeds, sampler, exclude='reverse_id', reverse_eids=reverse_eids,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) negative_sampler=dgl.dataloading.negative_sampler.Uniform(2),
return_indices=True))
graphs.append(g) graphs.append(g)
nids.append({'follow': seeds}) nids.append({'follow': seeds})
modes.append('link') modes.append('link')
...@@ -143,7 +167,7 @@ def test_neighbor_sampler_dataloader(): ...@@ -143,7 +167,7 @@ def test_neighbor_sampler_dataloader():
[{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}, [{'user': F.tensor([0, 1, 3, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)},
{'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}], {'user': F.tensor([4, 5], dtype=F.int64), 'game': F.tensor([0, 1, 2], dtype=F.int64)}],
[hg_sampler1, hg_sampler2]): [hg_sampler1, hg_sampler2]):
collators.append(dgl.dataloading.NodeCollator(hg, seeds, sampler)) collators.append(dgl.dataloading.NodeCollator(hg, seeds, sampler, return_indices=True))
graphs.append(hg) graphs.append(hg)
nids.append(seeds) nids.append(seeds)
modes.append('node') modes.append('node')
...@@ -152,26 +176,29 @@ def test_neighbor_sampler_dataloader(): ...@@ -152,26 +176,29 @@ def test_neighbor_sampler_dataloader():
[{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}, [{'follow': F.tensor([0, 1, 3, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)},
{'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}], {'follow': F.tensor([4, 5], dtype=F.int64), 'play': F.tensor([1, 3], dtype=F.int64)}],
[hg_sampler1, hg_sampler2]): [hg_sampler1, hg_sampler2]):
collators.append(dgl.dataloading.EdgeCollator(hg, seeds, sampler)) collators.append(dgl.dataloading.EdgeCollator(hg, seeds, sampler, return_indices=True))
graphs.append(hg) graphs.append(hg)
nids.append(seeds) nids.append(seeds)
modes.append('edge') modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes)) hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes,
return_indices=True))
graphs.append(hg) graphs.append(hg)
nids.append(seeds) nids.append(seeds)
modes.append('edge') modes.append('edge')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) hg, seeds, sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(2),
return_indices=True))
graphs.append(hg) graphs.append(hg)
nids.append(seeds) nids.append(seeds)
modes.append('link') modes.append('link')
collators.append(dgl.dataloading.EdgeCollator( collators.append(dgl.dataloading.EdgeCollator(
hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes, hg, seeds, sampler, exclude='reverse_types', reverse_etypes=reverse_etypes,
negative_sampler=dgl.dataloading.negative_sampler.Uniform(2))) negative_sampler=dgl.dataloading.negative_sampler.Uniform(2),
return_indices=True))
graphs.append(hg) graphs.append(hg)
nids.append(seeds) nids.append(seeds)
modes.append('link') modes.append('link')
...@@ -179,7 +206,7 @@ def test_neighbor_sampler_dataloader(): ...@@ -179,7 +206,7 @@ def test_neighbor_sampler_dataloader():
for _g, nid, collator, mode in zip(graphs, nids, collators, modes): for _g, nid, collator, mode in zip(graphs, nids, collators, modes):
dl = DataLoader( dl = DataLoader(
collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False) collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
_check_neighbor_sampling_dataloader(_g, nid, dl, mode) _check_neighbor_sampling_dataloader(_g, nid, dl, mode, collator)
def test_graph_dataloader(): def test_graph_dataloader():
batch_size = 16 batch_size = 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment