Unverified Commit d1827488 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 3e5137fe
"""Base classes and functionalities for dataloaders"""
import inspect
from collections.abc import Mapping

from .. import backend as F
from ..base import EID, NID
from ..convert import heterograph
from ..frame import LazyFeature
from ..transforms import compact_graphs
from ..utils import context_of, recursive_apply
def _set_lazy_features(x, xdata, feature_names):
if feature_names is None:
......@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
for type_, names in feature_names.items():
x[type_].data.update({k: LazyFeature(k) for k in names})
def set_node_lazy_features(g, feature_names):
"""Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
......@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.nodes, g.ndata, feature_names)
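A usage sketch (not part of this diff): a custom sampler typically calls these helpers on the subgraph it is about to return, so the DataLoader materializes the named columns during prefetching rather than eagerly. The graph size and the "feat" column name below are illustrative, assuming a DGL version (~0.8+) that ships the prefetching dataloader.

# Hedged sketch: marking a node feature column for lazy prefetching.
import dgl
import torch
from dgl.dataloading.base import set_node_lazy_features

g = dgl.rand_graph(100, 500)
g.ndata["feat"] = torch.randn(100, 16)
sg = g.subgraph(torch.arange(10))
# Replace the materialized column with a LazyFeature placeholder; DGL's
# DataLoader then fills it in during prefetching.
set_node_lazy_features(sg, ["feat"])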
def set_edge_lazy_features(g, feature_names):
"""Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
......@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.edges, g.edata, feature_names)
def set_src_lazy_features(g, feature_names):
"""Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
......@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)
def set_dst_lazy_features(g, feature_names):
"""Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
......@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)
class Sampler(object):
"""Base class for graph samplers.
......@@ -171,6 +178,7 @@ class Sampler(object):
def sample(self, g, indices):
return g.subgraph(indices)
"""
def sample(self, g, indices):
"""Abstract sample method.
......@@ -183,6 +191,7 @@ class Sampler(object):
"""
raise NotImplementedError
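To make the contract concrete, here is a minimal conforming subclass mirroring the docstring above; a sketch, assuming dgl.dataloading exports Sampler (DGL >= 0.8).

# Minimal Sampler subclass: return the node-induced subgraph of the seeds.
import dgl
import torch


class SubgraphSampler(dgl.dataloading.Sampler):
    def sample(self, g, indices):  # indices: minibatch of seed node IDs
        return g.subgraph(indices)


g = dgl.rand_graph(100, 500)
sg = SubgraphSampler().sample(g, torch.arange(10))
print(sg.num_nodes())  # 10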
class BlockSampler(Sampler):
"""Base class for sampling mini-batches in the form of Message-passing
Flow Graphs (MFGs).
......@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes.
"""
    def __init__(
        self,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
super().__init__()
self.prefetch_node_feats = prefetch_node_feats or []
self.prefetch_labels = prefetch_labels or []
......@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
set_edge_lazy_features(block, self.prefetch_edge_feats)
return input_nodes, output_nodes, blocks
    def sample(
        self, g, seed_nodes, exclude_eids=None
    ):  # pylint: disable=arguments-differ
"""Sample a list of blocks from the given seed nodes."""
result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
return self.assign_lazy_features(result)
......@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
            for k, v in eids.items()
        }
else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids
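A numeric sketch of the homogeneous branch above, in plain torch: if reverse_eid_map[i] is the ID of edge i's reverse, excluding seed edges [0, 2] also excludes their reverses. Values are illustrative.

# Illustrative: seeds plus their reverse edges end up excluded.
import torch

reverse_eid_map = torch.tensor([3, 2, 1, 0])  # edge i <-> edge (3 - i)
eids = torch.tensor([0, 2])
excluded = torch.cat([eids, reverse_eid_map[eids]], 0)
print(excluded)  # tensor([0, 2, 3, 1])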
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v)
        for k, v in reverse_etype_map.items()
    }
    exclude_eids.update(
        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
    )
return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
if exclude_mode is None:
return None
elif callable(exclude_mode):
return exclude_mode(eids)
elif F.is_tensor(exclude_mode) or (
        isinstance(exclude_mode, Mapping)
        and all(F.is_tensor(v) for v in exclude_mode.values())
    ):
return exclude_mode
    elif exclude_mode == "self":
return eids
    elif exclude_mode == "reverse_id":
        return _find_exclude_eids_with_reverse_id(
            g, eids, kwargs["reverse_eid_map"]
        )
    elif exclude_mode == "reverse_types":
        return _find_exclude_eids_with_reverse_types(
            g, eids, kwargs["reverse_etype_map"]
        )
else:
        raise ValueError("unsupported mode {}".format(exclude_mode))
def find_exclude_eids(
    g,
    seed_edges,
    exclude,
    reverse_eids=None,
    reverse_etypes=None,
    output_device=None,
):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
Parameters
......@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
exclude,
seed_edges,
reverse_eid_map=reverse_eids,
        reverse_etype_map=reverse_etypes,
    )
if exclude_eids is not None and output_device is not None:
        exclude_eids = recursive_apply(
            exclude_eids, lambda x: F.copy_to(x, output_device)
        )
return exclude_eids
class EdgePredictionSampler(Sampler):
"""Sampler class that wraps an existing sampler for node classification into another
one for edge classification or link prediction.
......@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
--------
as_edge_prediction_sampler
"""
    def __init__(
        self,
        sampler,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
        prefetch_labels=None,
    ):
super().__init__()
# Check if the sampler's sample method has an optional third argument.
argspec = inspect.getfullargspec(sampler.sample)
if len(argspec.args) < 4: # ['self', 'g', 'indices', 'exclude_eids']
raise TypeError(
"This sampler does not support edge or link prediction; please add an"
"optional third argument for edge IDs to exclude in its sample() method.")
"optional third argument for edge IDs to exclude in its sample() method."
)
self.reverse_eids = reverse_eids
self.reverse_etypes = reverse_etypes
self.exclude = exclude
......@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
def _build_neg_graph(self, g, seed_edges):
neg_srcdst = self.negative_sampler(g, seed_edges)
if not isinstance(neg_srcdst, Mapping):
            assert len(g.canonical_etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = context_of(seed_edges) if seed_edges is not None else g.device
neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                ),
            )
            for etype in g.canonical_etypes
        }
neg_pair_graph = heterograph(
            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        )
return neg_pair_graph
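The Mapping check above means that on heterogeneous graphs a negative sampler must return a per-etype dict of (src, dst) pairs. A hedged sketch of such a sampler follows; the class name and its internals are illustrative, and the built-in dgl.dataloading.negative_sampler.Uniform is the production equivalent.

# Sketch of a per-etype uniform negative sampler satisfying the dict contract.
import torch


class PerEtypeUniform:
    def __init__(self, k):
        self.k = k  # negatives per positive edge

    def __call__(self, g, eids_dict):
        # Assumes eids_dict maps canonical etypes to edge ID tensors, which is
        # what EdgePredictionSampler passes for heterogeneous graphs.
        result = {}
        for etype, eids in eids_dict.items():
            src, _ = g.find_edges(eids, etype=etype)
            src = src.repeat_interleave(self.k)
            # Corrupt destinations by sampling uniformly among dst-type nodes.
            dst = torch.randint(0, g.num_nodes(etype[2]), (src.shape[0],))
            result[etype] = (src, dst)
        return result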
def assign_lazy_features(self, result):
......@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
negative pairs as edges.
"""
if isinstance(seed_edges, Mapping):
            seed_edges = {
                g.to_canonical_etype(k): v for k, v in seed_edges.items()
            }
exclude = self.exclude
pair_graph = g.edge_subgraph(
            seed_edges, relabel_nodes=False, output_device=self.output_device
        )
eids = pair_graph.edata[EID]
if self.negative_sampler is not None:
......@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
seed_nodes = pair_graph.ndata[NID]
exclude_eids = find_exclude_eids(
            g,
            seed_edges,
            exclude,
            self.reverse_eids,
            self.reverse_etypes,
            self.output_device,
        )
        input_nodes, _, blocks = self.sampler.sample(
            g, seed_nodes, exclude_eids
        )
if self.negative_sampler is None:
return self.assign_lazy_features((input_nodes, pair_graph, blocks))
else:
            return self.assign_lazy_features(
                (input_nodes, pair_graph, neg_graph, blocks)
            )
def as_edge_prediction_sampler(
    sampler,
    exclude=None,
    reverse_eids=None,
    reverse_etypes=None,
    negative_sampler=None,
    prefetch_labels=None,
):
"""Create an edge-wise sampler from a node-wise sampler.
For each batch of edges, the sampler applies the provided node-wise sampler to
......@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
"""
return EdgePredictionSampler(
        sampler,
        exclude=exclude,
        reverse_eids=reverse_eids,
        reverse_etypes=reverse_etypes,
        negative_sampler=negative_sampler,
        prefetch_labels=prefetch_labels,
    )
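A wiring sketch for link prediction, tying the pieces together (mirrors the docstring example; the reverse-edge mapping below is a stand-in, not a real involution of this random graph).

# Hedged link-prediction wiring sketch.
import dgl
import torch

g = dgl.rand_graph(100, 500)
sampler = dgl.dataloading.NeighborSampler([10, 10])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler,
    exclude="reverse_id",
    reverse_eids=torch.arange(g.num_edges()).flip(0),  # placeholder mapping
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_edges()), sampler, batch_size=32, shuffle=True
)
for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
    break  # train_on(input_nodes, pair_graph, neg_pair_graph, blocks)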
"""Distributed dataloaders.
"""
import inspect
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Mapping

from .. import backend as F, transforms, utils
from ..base import EID, NID
from ..convert import heterograph
from ..distributed import DistDataLoader
......@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
            for k, v in eids.items()
        }
else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v)
        for k, v in reverse_etype_map.items()
    }
    exclude_eids.update(
        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
    )
return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
......@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""
if exclude_mode is None:
return None
    elif exclude_mode == "self":
        return eids
    elif exclude_mode == "reverse_id":
        return _find_exclude_eids_with_reverse_id(
            g, eids, kwargs["reverse_eid_map"]
        )
    elif exclude_mode == "reverse_types":
        return _find_exclude_eids_with_reverse_types(
            g, eids, kwargs["reverse_etype_map"]
        )
    else:
        raise ValueError("unsupported mode {}".format(exclude_mode))
class Collator(ABC):
......@@ -100,6 +109,7 @@ class Collator(ABC):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
@abstractproperty
def dataset(self):
"""Returns the dataset object of the collator."""
......@@ -122,6 +132,7 @@ class Collator(ABC):
"""
raise NotImplementedError
class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling.
......@@ -155,14 +166,16 @@ class NodeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, g, nids, graph_sampler):
self.g = g
if not isinstance(nids, Mapping):
            assert (
                len(g.ntypes) == 1
            ), "nids should be a dict of node type and ids for graph with multiple node types"
        self.graph_sampler = graph_sampler
        self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
self._dataset = utils.maybe_flatten_dict(self.nids)
@property
......@@ -197,12 +210,15 @@ class NodeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g, items, "items")
        input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(
            self.g, items
        )
return input_nodes, output_nodes, blocks
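A hedged sketch of the legacy collator pattern: pair the collator's dataset and collate function with a plain torch DataLoader (assumes a DGL version that still ships NodeCollator).

# Legacy NodeCollator pattern with a plain torch DataLoader.
import dgl
import torch

g = dgl.rand_graph(100, 500)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
collator = dgl.dataloading.NodeCollator(g, torch.arange(g.num_nodes()), sampler)
dataloader = torch.utils.data.DataLoader(
    collator.dataset, collate_fn=collator.collate, batch_size=32, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    break  # forward/backward pass goes here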
class EdgeCollator(Collator):
"""DGL collator to combine edges and their computation dependencies within a minibatch for
training edge classification, edge regression, or link prediction on a single graph
......@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
    def __init__(
        self,
        g,
        eids,
        graph_sampler,
        g_sampling=None,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
    ):
        self.g = g
        if not isinstance(eids, Mapping):
            assert (
                len(g.etypes) == 1
            ), "eids should be a dict of etype and ids for graph with multiple etypes"
self.graph_sampler = graph_sampler
        # One may wish to iterate over the edges in one graph while performing sampling in
......@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
self.reverse_etypes = reverse_etypes
self.negative_sampler = negative_sampler
        self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
self._dataset = utils.maybe_flatten_dict(self.eids)
@property
......@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID]
......@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
            reverse_etype_map=self.reverse_etypes,
        )
        input_nodes, _, blocks = self.graph_sampler.sample_blocks(
            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
        )
return input_nodes, pair_graph, blocks
......@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items)
if not isinstance(neg_srcdst, Mapping):
            assert len(self.g.etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}
# Get dtype from a tuple of tensors
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = F.context(pair_graph)
neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx),
                    F.copy_to(F.tensor([], dtype), ctx),
                ),
            )
            for etype in self.g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges,
            {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes},
        )
        pair_graph, neg_pair_graph = transforms.compact_graphs(
            [pair_graph, neg_pair_graph]
        )
pair_graph.edata[EID] = induced_edges
seed_nodes = pair_graph.ndata[NID]
......@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
            reverse_etype_map=self.reverse_etypes,
        )
        input_nodes, _, blocks = self.graph_sampler.sample_blocks(
            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
        )
return input_nodes, pair_graph, neg_pair_graph, blocks
......@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
def _remove_kwargs_dist(kwargs):
    if "num_workers" in kwargs:
        del kwargs["num_workers"]
    if "pin_memory" in kwargs:
        del kwargs["pin_memory"]
        print("Distributed DataLoaders do not support pin_memory.")
return kwargs
class DistNodeDataLoader(DistDataLoader):
"""Sampled graph data loader over nodes for distributed graph storage.
......@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
--------
dgl.dataloading.DataLoader
"""
def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
......@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
dataloader_kwargs[k] = v
if device is None:
# for the distributed case default to the CPU
            device = "cpu"
        assert (
            device == "cpu"
        ), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features. Fall back to the normal solution.
self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs)
        super().__init__(
            self.collator.dataset,
            collate_fn=self.collator.collate,
            **dataloader_kwargs
        )
self.device = device
class DistEdgeDataLoader(DistDataLoader):
"""Sampled graph data loader over edges for distributed graph storage.
......@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
--------
dgl.dataloading.DataLoader
"""
def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
......@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
if device is None:
# for the distributed case default to the CPU
            device = "cpu"
        assert (
            device == "cpu"
        ), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features. Fall back to the normal solution.
self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs)
        super().__init__(
            self.collator.dataset,
            collate_fn=self.collator.collate,
            **dataloader_kwargs
        )
self.device = device
......@@ -17,11 +17,11 @@
#
"""Data loading components for labor sampling"""
from .. import backend as F
from ..base import EID, NID
from ..random import choice
from ..transforms import to_block
from .base import BlockSampler
class LaborSampler(BlockSampler):
......@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
)
block.edata[EID] = eid
if len(g.canonical_etypes) > 1:
            for etype, importance in zip(g.canonical_etypes, importances):
if importance.shape[0] == block.num_edges(etype):
block.edata["edge_weights"][etype] = importance
elif importances[0].shape[0] == block.num_edges():
......
"""Data loading components for neighbor sampling"""
from ..base import EID, NID
from ..transforms import to_block
from .base import BlockSampler
class NeighborSampler(BlockSampler):
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
......@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
    def __init__(
        self,
        fanouts,
        edge_dir="in",
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
self.fanouts = fanouts
self.edge_dir = edge_dir
if mask is not None and prob is not None:
raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
self.prob = prob or mask
self.replace = replace
......@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
blocks = []
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
                seed_nodes,
                fanout,
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
......@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
return seed_nodes, output_nodes, blocks
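A node-classification sketch using this sampler (mirrors the class docstring; graph, sizes, and the "feat" column are illustrative).

# Hedged node-classification sketch with NeighborSampler.
import dgl
import torch

g = dgl.rand_graph(1000, 5000)
g.ndata["feat"] = torch.randn(1000, 16)
sampler = dgl.dataloading.NeighborSampler(
    [5, 10, 15],  # one fanout per GNN layer, first (input-side) layer first
    prefetch_node_feats=["feat"],
)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler, batch_size=64, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    x = blocks[0].srcdata["feat"]  # features of the sampled input nodes
    break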
MultiLayerNeighborSampler = NeighborSampler
class MultiLayerFullNeighborSampler(NeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
......@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, num_layers, **kwargs):
super().__init__([-1] * num_layers, **kwargs)
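Since a fanout of -1 keeps all neighbors, the two samplers below are configured identically; a small equivalence sketch.

# Full-neighbor sampling is NeighborSampler with all fanouts set to -1.
import dgl

s1 = dgl.dataloading.MultiLayerFullNeighborSampler(3)
s2 = dgl.dataloading.NeighborSampler([-1, -1, -1])
assert s1.fanouts == s2.fanouts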
"""ShaDow-GNN subgraph samplers."""
from .. import transforms
from ..base import NID
from ..sampling.utils import EidExcluder
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
class ShaDowKHopSampler(Sampler):
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
......@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
    def __init__(
        self,
        fanouts,
        replace=False,
        prob=None,
        prefetch_node_feats=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
super().__init__()
self.fanouts = fanouts
self.replace = replace
......@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
self.prefetch_edge_feats = prefetch_edge_feats
self.output_device = output_device
    def sample(
        self, g, seed_nodes, exclude_eids=None
    ):  # pylint: disable=arguments-differ
"""Sampling function.
Parameters
......@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
output_nodes = seed_nodes
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
                seed_nodes,
                fanout,
                output_device=self.output_device,
                replace=self.replace,
                prob=self.prob,
                exclude_edges=exclude_eids,
            )
block = transforms.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[NID]
        subg = g.subgraph(
            seed_nodes, relabel_nodes=True, output_device=self.output_device
        )
if exclude_eids is not None:
subg = EidExcluder(exclude_eids)(subg)
......
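A usage sketch for ShaDowKHopSampler (mirrors the class docstring): the loader yields one node-induced subgraph per minibatch instead of a list of MFGs.

# Hedged ShaDowKHopSampler usage sketch.
import dgl
import torch

g = dgl.rand_graph(200, 2000)
sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler, batch_size=32, shuffle=True
)
for input_nodes, output_nodes, subgraph in dataloader:
    break  # run the GNN on `subgraph`, read predictions at `output_nodes`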
......@@ -7,8 +7,7 @@ from itertools import product
from .base import BuiltinFunction, TargetCode
__all__ = ["copy_u", "copy_e",
"BinaryMessageFunction", "CopyMessageFunction"]
__all__ = ["copy_u", "copy_e", "BinaryMessageFunction", "CopyMessageFunction"]
class MessageFunction(BuiltinFunction):
......@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
--------
u_mul_e
"""
def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
self.binary_op = binary_op
self.lhs = lhs
......@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
--------
copy_u
"""
def __init__(self, target, in_field, out_field):
self.target = target
self.in_field = in_field
......@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
--------
>>> import dgl
>>> message_func = dgl.function.{}('h', 'h', 'm')
""".format(binary_op,
""".format(
binary_op,
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
name)
name,
)
def func(lhs_field, rhs_field, out):
return BinaryMessageFunction(
            binary_op,
            _TARGET_MAP[lhs],
            _TARGET_MAP[rhs],
            lhs_field,
            rhs_field,
            out,
        )
func.__name__ = name
func.__doc__ = docstring
return func
......@@ -177,4 +186,5 @@ def _register_builtin_message_func():
setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__)
_register_builtin_message_func()
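The registered builtins are used as message functions in update_all; a short sketch with u_mul_e follows (feature names "h", "w", "m" are illustrative).

# Using a generated builtin (u_mul_e) inside update_all.
import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(10, 40)
g.ndata["h"] = torch.randn(10, 4)
g.edata["w"] = torch.rand(40, 1)
# Message: source "h" times edge "w" (broadcast); reduce: sum into "h_new".
g.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h_new"))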
"""Module for various graph generator functions."""
from . import backend as F, convert, random
__all__ = ["rand_graph", "rand_bipartite"]
......
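A quick sketch of the two generators exported here; node/edge counts and type names are illustrative.

# The two exported random-graph generators.
import dgl

g = dgl.rand_graph(100, 500)  # homogeneous graph, edges sampled uniformly
bg = dgl.rand_bipartite("user", "plays", "game", 50, 100, 500)
print(g.num_edges(), bg.num_edges())  # 500 500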
"""Python interfaces to DGL farthest point sampler."""
import numpy as np
from .. import backend as F, ndarray as nd
from .._ffi.base import DGLError
from .._ffi.function import _init_api
......
......@@ -5,11 +5,10 @@ import networkx as nx
import numpy as np
import scipy
from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import dgl_warning, DGLError
class BoolFlag(object):
......
......@@ -7,12 +7,11 @@ import sys
import numpy as np
import scipy
from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle
from .base import dgl_warning, DGLError
from .graph_index import from_coo
......