Unverified Commit d1827488 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 3e5137fe
"""Base classes and functionalities for dataloaders""" """Base classes and functionalities for dataloaders"""
from collections.abc import Mapping
import inspect import inspect
from ..base import NID, EID from collections.abc import Mapping
from ..convert import heterograph
from .. import backend as F from .. import backend as F
from ..transforms import compact_graphs from ..base import EID, NID
from ..convert import heterograph
from ..frame import LazyFeature from ..frame import LazyFeature
from ..utils import recursive_apply, context_of from ..transforms import compact_graphs
from ..utils import context_of, recursive_apply
def _set_lazy_features(x, xdata, feature_names): def _set_lazy_features(x, xdata, feature_names):
if feature_names is None: if feature_names is None:
...@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names): ...@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
for type_, names in feature_names.items(): for type_, names in feature_names.items():
x[type_].data.update({k: LazyFeature(k) for k in names}) x[type_].data.update({k: LazyFeature(k) for k in names})
def set_node_lazy_features(g, feature_names): def set_node_lazy_features(g, feature_names):
"""Assign lazy features to the ``ndata`` of the input graph for prefetching optimization. """Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
...@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names): ...@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
""" """
return _set_lazy_features(g.nodes, g.ndata, feature_names) return _set_lazy_features(g.nodes, g.ndata, feature_names)
def set_edge_lazy_features(g, feature_names): def set_edge_lazy_features(g, feature_names):
"""Assign lazy features to the ``edata`` of the input graph for prefetching optimization. """Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
...@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names): ...@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
""" """
return _set_lazy_features(g.edges, g.edata, feature_names) return _set_lazy_features(g.edges, g.edata, feature_names)
def set_src_lazy_features(g, feature_names): def set_src_lazy_features(g, feature_names):
"""Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization. """Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
...@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names): ...@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
""" """
return _set_lazy_features(g.srcnodes, g.srcdata, feature_names) return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)
def set_dst_lazy_features(g, feature_names): def set_dst_lazy_features(g, feature_names):
"""Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization. """Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
...@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names): ...@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
""" """
return _set_lazy_features(g.dstnodes, g.dstdata, feature_names) return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)
class Sampler(object): class Sampler(object):
"""Base class for graph samplers. """Base class for graph samplers.
...@@ -171,6 +178,7 @@ class Sampler(object): ...@@ -171,6 +178,7 @@ class Sampler(object):
def sample(self, g, indices): def sample(self, g, indices):
return g.subgraph(indices) return g.subgraph(indices)
""" """
def sample(self, g, indices): def sample(self, g, indices):
"""Abstract sample method. """Abstract sample method.
...@@ -183,6 +191,7 @@ class Sampler(object): ...@@ -183,6 +191,7 @@ class Sampler(object):
""" """
raise NotImplementedError raise NotImplementedError
class BlockSampler(Sampler): class BlockSampler(Sampler):
"""Base class for sampling mini-batches in the form of Message-passing """Base class for sampling mini-batches in the form of Message-passing
Flow Graphs (MFGs). Flow Graphs (MFGs).
...@@ -211,8 +220,14 @@ class BlockSampler(Sampler): ...@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
The device of the output subgraphs or MFGs. Default is the same as the The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes. minibatch of seed nodes.
""" """
def __init__(self, prefetch_node_feats=None, prefetch_labels=None,
prefetch_edge_feats=None, output_device=None): def __init__(
self,
prefetch_node_feats=None,
prefetch_labels=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__() super().__init__()
self.prefetch_node_feats = prefetch_node_feats or [] self.prefetch_node_feats = prefetch_node_feats or []
self.prefetch_labels = prefetch_labels or [] self.prefetch_labels = prefetch_labels or []
...@@ -238,7 +253,9 @@ class BlockSampler(Sampler): ...@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
set_edge_lazy_features(block, self.prefetch_edge_feats) set_edge_lazy_features(block, self.prefetch_edge_feats)
return input_nodes, output_nodes, blocks return input_nodes, output_nodes, blocks
def sample(self, g, seed_nodes, exclude_eids=None): # pylint: disable=arguments-differ def sample(
self, g, seed_nodes, exclude_eids=None
): # pylint: disable=arguments-differ
"""Sample a list of blocks from the given seed nodes.""" """Sample a list of blocks from the given seed nodes."""
result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids) result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
return self.assign_lazy_features(result) return self.assign_lazy_features(result)
...@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map): ...@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()} eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = { exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0) k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
for k, v in eids.items()} for k, v in eids.items()
}
else: else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0) exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids return exclude_eids
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()} exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = { reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v) g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items()} for k, v in reverse_etype_map.items()
exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()}) }
exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()}
)
return exclude_eids return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs): def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
if exclude_mode is None: if exclude_mode is None:
return None return None
elif callable(exclude_mode): elif callable(exclude_mode):
return exclude_mode(eids) return exclude_mode(eids)
elif F.is_tensor(exclude_mode) or ( elif F.is_tensor(exclude_mode) or (
isinstance(exclude_mode, Mapping) and isinstance(exclude_mode, Mapping)
all(F.is_tensor(v) for v in exclude_mode.values())): and all(F.is_tensor(v) for v in exclude_mode.values())
):
return exclude_mode return exclude_mode
elif exclude_mode == 'self': elif exclude_mode == "self":
return eids return eids
elif exclude_mode == 'reverse_id': elif exclude_mode == "reverse_id":
return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map']) return _find_exclude_eids_with_reverse_id(
elif exclude_mode == 'reverse_types': g, eids, kwargs["reverse_eid_map"]
return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map']) )
elif exclude_mode == "reverse_types":
return _find_exclude_eids_with_reverse_types(
g, eids, kwargs["reverse_etype_map"]
)
else: else:
raise ValueError('unsupported mode {}'.format(exclude_mode)) raise ValueError("unsupported mode {}".format(exclude_mode))
def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=None,
output_device=None): def find_exclude_eids(
g,
seed_edges,
exclude,
reverse_eids=None,
reverse_etypes=None,
output_device=None,
):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`. """Find all edge IDs to exclude according to :attr:`exclude_mode`.
Parameters Parameters
...@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes= ...@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
exclude, exclude,
seed_edges, seed_edges,
reverse_eid_map=reverse_eids, reverse_eid_map=reverse_eids,
reverse_etype_map=reverse_etypes) reverse_etype_map=reverse_etypes,
)
if exclude_eids is not None and output_device is not None: if exclude_eids is not None and output_device is not None:
exclude_eids = recursive_apply(exclude_eids, lambda x: F.copy_to(x, output_device)) exclude_eids = recursive_apply(
exclude_eids, lambda x: F.copy_to(x, output_device)
)
return exclude_eids return exclude_eids
class EdgePredictionSampler(Sampler): class EdgePredictionSampler(Sampler):
"""Sampler class that wraps an existing sampler for node classification into another """Sampler class that wraps an existing sampler for node classification into another
one for edge classification or link prediction. one for edge classification or link prediction.
...@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler): ...@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
-------- --------
as_edge_prediction_sampler as_edge_prediction_sampler
""" """
def __init__(self, sampler, exclude=None, reverse_eids=None,
reverse_etypes=None, negative_sampler=None, prefetch_labels=None): def __init__(
self,
sampler,
exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
prefetch_labels=None,
):
super().__init__() super().__init__()
# Check if the sampler's sample method has an optional third argument. # Check if the sampler's sample method has an optional third argument.
argspec = inspect.getfullargspec(sampler.sample) argspec = inspect.getfullargspec(sampler.sample)
if len(argspec.args) < 4: # ['self', 'g', 'indices', 'exclude_eids'] if len(argspec.args) < 4: # ['self', 'g', 'indices', 'exclude_eids']
raise TypeError( raise TypeError(
"This sampler does not support edge or link prediction; please add an" "This sampler does not support edge or link prediction; please add an"
"optional third argument for edge IDs to exclude in its sample() method.") "optional third argument for edge IDs to exclude in its sample() method."
)
self.reverse_eids = reverse_eids self.reverse_eids = reverse_eids
self.reverse_etypes = reverse_etypes self.reverse_etypes = reverse_etypes
self.exclude = exclude self.exclude = exclude
...@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler): ...@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
def _build_neg_graph(self, g, seed_edges): def _build_neg_graph(self, g, seed_edges):
neg_srcdst = self.negative_sampler(g, seed_edges) neg_srcdst = self.negative_sampler(g, seed_edges)
if not isinstance(neg_srcdst, Mapping): if not isinstance(neg_srcdst, Mapping):
assert len(g.canonical_etypes) == 1, \ assert len(g.canonical_etypes) == 1, (
'graph has multiple or no edge types; '\ "graph has multiple or no edge types; "
'please return a dict in negative sampler.' "please return a dict in negative sampler."
)
neg_srcdst = {g.canonical_etypes[0]: neg_srcdst} neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}
dtype = F.dtype(list(neg_srcdst.values())[0][0]) dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = context_of(seed_edges) if seed_edges is not None else g.device ctx = context_of(seed_edges) if seed_edges is not None else g.device
neg_edges = { neg_edges = {
etype: neg_srcdst.get(etype, etype: neg_srcdst.get(
(F.copy_to(F.tensor([], dtype), ctx=ctx), etype,
F.copy_to(F.tensor([], dtype), ctx=ctx))) (
for etype in g.canonical_etypes} F.copy_to(F.tensor([], dtype), ctx=ctx),
F.copy_to(F.tensor([], dtype), ctx=ctx),
),
)
for etype in g.canonical_etypes
}
neg_pair_graph = heterograph( neg_pair_graph = heterograph(
neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}) neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
)
return neg_pair_graph return neg_pair_graph
def assign_lazy_features(self, result): def assign_lazy_features(self, result):
...@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler): ...@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
negative pairs as edges. negative pairs as edges.
""" """
if isinstance(seed_edges, Mapping): if isinstance(seed_edges, Mapping):
seed_edges = {g.to_canonical_etype(k): v for k, v in seed_edges.items()} seed_edges = {
g.to_canonical_etype(k): v for k, v in seed_edges.items()
}
exclude = self.exclude exclude = self.exclude
pair_graph = g.edge_subgraph( pair_graph = g.edge_subgraph(
seed_edges, relabel_nodes=False, output_device=self.output_device) seed_edges, relabel_nodes=False, output_device=self.output_device
)
eids = pair_graph.edata[EID] eids = pair_graph.edata[EID]
if self.negative_sampler is not None: if self.negative_sampler is not None:
...@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler): ...@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
seed_nodes = pair_graph.ndata[NID] seed_nodes = pair_graph.ndata[NID]
exclude_eids = find_exclude_eids( exclude_eids = find_exclude_eids(
g, seed_edges, exclude, self.reverse_eids, self.reverse_etypes, g,
self.output_device) seed_edges,
exclude,
self.reverse_eids,
self.reverse_etypes,
self.output_device,
)
input_nodes, _, blocks = self.sampler.sample(g, seed_nodes, exclude_eids) input_nodes, _, blocks = self.sampler.sample(
g, seed_nodes, exclude_eids
)
if self.negative_sampler is None: if self.negative_sampler is None:
return self.assign_lazy_features((input_nodes, pair_graph, blocks)) return self.assign_lazy_features((input_nodes, pair_graph, blocks))
else: else:
return self.assign_lazy_features((input_nodes, pair_graph, neg_graph, blocks)) return self.assign_lazy_features(
(input_nodes, pair_graph, neg_graph, blocks)
)
def as_edge_prediction_sampler( def as_edge_prediction_sampler(
sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, sampler,
prefetch_labels=None): exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
prefetch_labels=None,
):
"""Create an edge-wise sampler from a node-wise sampler. """Create an edge-wise sampler from a node-wise sampler.
For each batch of edges, the sampler applies the provided node-wise sampler to For each batch of edges, the sampler applies the provided node-wise sampler to
...@@ -571,5 +644,10 @@ def as_edge_prediction_sampler( ...@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks) ... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
""" """
return EdgePredictionSampler( return EdgePredictionSampler(
sampler, exclude=exclude, reverse_eids=reverse_eids, reverse_etypes=reverse_etypes, sampler,
negative_sampler=negative_sampler, prefetch_labels=prefetch_labels) exclude=exclude,
reverse_eids=reverse_eids,
reverse_etypes=reverse_etypes,
negative_sampler=negative_sampler,
prefetch_labels=prefetch_labels,
)
"""Distributed dataloaders. """Distributed dataloaders.
""" """
import inspect import inspect
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Mapping from collections.abc import Mapping
from abc import ABC, abstractproperty, abstractmethod
from .. import transforms from .. import backend as F, transforms, utils
from ..base import NID, EID from ..base import EID, NID
from .. import backend as F
from .. import utils
from ..convert import heterograph from ..convert import heterograph
from ..distributed import DistDataLoader from ..distributed import DistDataLoader
...@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map): ...@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()} eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = { exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0) k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
for k, v in eids.items()} for k, v in eids.items()
}
else: else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0) exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids return exclude_eids
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map): def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()} exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = { reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v) g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items()} for k, v in reverse_etype_map.items()
exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()}) }
exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()}
)
return exclude_eids return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs): def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`. """Find all edge IDs to exclude according to :attr:`exclude_mode`.
...@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs): ...@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
""" """
if exclude_mode is None: if exclude_mode is None:
return None return None
elif exclude_mode == 'self': elif exclude_mode == "self":
return eids return eids
elif exclude_mode == 'reverse_id': elif exclude_mode == "reverse_id":
return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map']) return _find_exclude_eids_with_reverse_id(
elif exclude_mode == 'reverse_types': g, eids, kwargs["reverse_eid_map"]
return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map']) )
elif exclude_mode == "reverse_types":
return _find_exclude_eids_with_reverse_types(
g, eids, kwargs["reverse_etype_map"]
)
else: else:
raise ValueError('unsupported mode {}'.format(exclude_mode)) raise ValueError("unsupported mode {}".format(exclude_mode))
class Collator(ABC): class Collator(ABC):
...@@ -100,6 +109,7 @@ class Collator(ABC): ...@@ -100,6 +109,7 @@ class Collator(ABC):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
@abstractproperty @abstractproperty
def dataset(self): def dataset(self):
"""Returns the dataset object of the collator.""" """Returns the dataset object of the collator."""
...@@ -122,6 +132,7 @@ class Collator(ABC): ...@@ -122,6 +132,7 @@ class Collator(ABC):
""" """
raise NotImplementedError raise NotImplementedError
class NodeCollator(Collator): class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for """DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling. training node classification or regression on a single graph with neighborhood sampling.
...@@ -155,14 +166,16 @@ class NodeCollator(Collator): ...@@ -155,14 +166,16 @@ class NodeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, g, nids, graph_sampler): def __init__(self, g, nids, graph_sampler):
self.g = g self.g = g
if not isinstance(nids, Mapping): if not isinstance(nids, Mapping):
assert len(g.ntypes) == 1, \ assert (
"nids should be a dict of node type and ids for graph with multiple node types" len(g.ntypes) == 1
), "nids should be a dict of node type and ids for graph with multiple node types"
self.graph_sampler = graph_sampler self.graph_sampler = graph_sampler
self.nids = utils.prepare_tensor_or_dict(g, nids, 'nids') self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
self._dataset = utils.maybe_flatten_dict(self.nids) self._dataset = utils.maybe_flatten_dict(self.nids)
@property @property
...@@ -197,12 +210,15 @@ class NodeCollator(Collator): ...@@ -197,12 +210,15 @@ class NodeCollator(Collator):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g, items, 'items') items = utils.prepare_tensor_or_dict(self.g, items, "items")
input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(self.g, items) input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(
self.g, items
)
return input_nodes, output_nodes, blocks return input_nodes, output_nodes, blocks
class EdgeCollator(Collator): class EdgeCollator(Collator):
"""DGL collator to combine edges and their computation dependencies within a minibatch for """DGL collator to combine edges and their computation dependencies within a minibatch for
training edge classification, edge regression, or link prediction on a single graph training edge classification, edge regression, or link prediction on a single graph
...@@ -380,12 +396,23 @@ class EdgeCollator(Collator): ...@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
reverse_eids=None, reverse_etypes=None, negative_sampler=None): def __init__(
self,
g,
eids,
graph_sampler,
g_sampling=None,
exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
):
self.g = g self.g = g
if not isinstance(eids, Mapping): if not isinstance(eids, Mapping):
assert len(g.etypes) == 1, \ assert (
"eids should be a dict of etype and ids for graph with multiple etypes" len(g.etypes) == 1
), "eids should be a dict of etype and ids for graph with multiple etypes"
self.graph_sampler = graph_sampler self.graph_sampler = graph_sampler
# One may wish to iterate over the edges in one graph while perform sampling in # One may wish to iterate over the edges in one graph while perform sampling in
...@@ -404,7 +431,7 @@ class EdgeCollator(Collator): ...@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
self.reverse_etypes = reverse_etypes self.reverse_etypes = reverse_etypes
self.negative_sampler = negative_sampler self.negative_sampler = negative_sampler
self.eids = utils.prepare_tensor_or_dict(g, eids, 'eids') self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
self._dataset = utils.maybe_flatten_dict(self.eids) self._dataset = utils.maybe_flatten_dict(self.eids)
@property @property
...@@ -415,7 +442,7 @@ class EdgeCollator(Collator): ...@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items) pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID] seed_nodes = pair_graph.ndata[NID]
...@@ -425,10 +452,12 @@ class EdgeCollator(Collator): ...@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
self.exclude, self.exclude,
items, items,
reverse_eid_map=self.reverse_eids, reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes) reverse_etype_map=self.reverse_etypes,
)
input_nodes, _, blocks = self.graph_sampler.sample_blocks( input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids) self.g_sampling, seed_nodes, exclude_eids=exclude_eids
)
return input_nodes, pair_graph, blocks return input_nodes, pair_graph, blocks
...@@ -436,28 +465,39 @@ class EdgeCollator(Collator): ...@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple): if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict # returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items) items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items') items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False) pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID] induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items) neg_srcdst = self.negative_sampler(self.g, items)
if not isinstance(neg_srcdst, Mapping): if not isinstance(neg_srcdst, Mapping):
assert len(self.g.etypes) == 1, \ assert len(self.g.etypes) == 1, (
'graph has multiple or no edge types; '\ "graph has multiple or no edge types; "
'please return a dict in negative sampler.' "please return a dict in negative sampler."
)
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst} neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}
# Get dtype from a tuple of tensors # Get dtype from a tuple of tensors
dtype = F.dtype(list(neg_srcdst.values())[0][0]) dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = F.context(pair_graph) ctx = F.context(pair_graph)
neg_edges = { neg_edges = {
etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx), etype: neg_srcdst.get(
F.copy_to(F.tensor([], dtype), ctx))) etype,
for etype in self.g.canonical_etypes} (
F.copy_to(F.tensor([], dtype), ctx),
F.copy_to(F.tensor([], dtype), ctx),
),
)
for etype in self.g.canonical_etypes
}
neg_pair_graph = heterograph( neg_pair_graph = heterograph(
neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes}) neg_edges,
{ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes},
)
pair_graph, neg_pair_graph = transforms.compact_graphs([pair_graph, neg_pair_graph]) pair_graph, neg_pair_graph = transforms.compact_graphs(
[pair_graph, neg_pair_graph]
)
pair_graph.edata[EID] = induced_edges pair_graph.edata[EID] = induced_edges
seed_nodes = pair_graph.ndata[NID] seed_nodes = pair_graph.ndata[NID]
...@@ -467,10 +507,12 @@ class EdgeCollator(Collator): ...@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
self.exclude, self.exclude,
items, items,
reverse_eid_map=self.reverse_eids, reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes) reverse_etype_map=self.reverse_etypes,
)
input_nodes, _, blocks = self.graph_sampler.sample_blocks( input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids) self.g_sampling, seed_nodes, exclude_eids=exclude_eids
)
return input_nodes, pair_graph, neg_pair_graph, blocks return input_nodes, pair_graph, neg_pair_graph, blocks
...@@ -517,13 +559,14 @@ class EdgeCollator(Collator): ...@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
def _remove_kwargs_dist(kwargs): def _remove_kwargs_dist(kwargs):
if 'num_workers' in kwargs: if "num_workers" in kwargs:
del kwargs['num_workers'] del kwargs["num_workers"]
if 'pin_memory' in kwargs: if "pin_memory" in kwargs:
del kwargs['pin_memory'] del kwargs["pin_memory"]
print('Distributed DataLoaders do not support pin_memory.') print("Distributed DataLoaders do not support pin_memory.")
return kwargs return kwargs
class DistNodeDataLoader(DistDataLoader): class DistNodeDataLoader(DistDataLoader):
"""Sampled graph data loader over nodes for distributed graph storage. """Sampled graph data loader over nodes for distributed graph storage.
...@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader): ...@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
-------- --------
dgl.dataloading.DataLoader dgl.dataloading.DataLoader
""" """
def __init__(self, g, nids, graph_sampler, device=None, **kwargs): def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
...@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader): ...@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
if device is None: if device is None:
# for the distributed case default to the CPU # for the distributed case default to the CPU
device = 'cpu' device = "cpu"
assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' assert (
device == "cpu"
), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs # Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution # and does not copy features. Fallback to normal solution
self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs) self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs) _remove_kwargs_dist(dataloader_kwargs)
super().__init__(self.collator.dataset, super().__init__(
self.collator.dataset,
collate_fn=self.collator.collate, collate_fn=self.collator.collate,
**dataloader_kwargs) **dataloader_kwargs
)
self.device = device self.device = device
class DistEdgeDataLoader(DistDataLoader): class DistEdgeDataLoader(DistDataLoader):
"""Sampled graph data loader over edges for distributed graph storage. """Sampled graph data loader over edges for distributed graph storage.
...@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader): ...@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
-------- --------
dgl.dataloading.DataLoader dgl.dataloading.DataLoader
""" """
def __init__(self, g, eids, graph_sampler, device=None, **kwargs): def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
collator_kwargs = {} collator_kwargs = {}
dataloader_kwargs = {} dataloader_kwargs = {}
...@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader): ...@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
if device is None: if device is None:
# for the distributed case default to the CPU # for the distributed case default to the CPU
device = 'cpu' device = "cpu"
assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.' assert (
device == "cpu"
), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs # Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution # and does not copy features. Fallback to normal solution
self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs) self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs) _remove_kwargs_dist(dataloader_kwargs)
super().__init__(self.collator.dataset, super().__init__(
self.collator.dataset,
collate_fn=self.collator.collate, collate_fn=self.collator.collate,
**dataloader_kwargs) **dataloader_kwargs
)
self.device = device self.device = device
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
# #
"""Data loading components for labor sampling""" """Data loading components for labor sampling"""
from ..base import NID, EID from .. import backend as F
from ..base import EID, NID
from ..random import choice
from ..transforms import to_block from ..transforms import to_block
from .base import BlockSampler from .base import BlockSampler
from ..random import choice
from .. import backend as F
class LaborSampler(BlockSampler): class LaborSampler(BlockSampler):
...@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler): ...@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
) )
block.edata[EID] = eid block.edata[EID] = eid
if len(g.canonical_etypes) > 1: if len(g.canonical_etypes) > 1:
for etype, importance in zip( for etype, importance in zip(g.canonical_etypes, importances):
g.canonical_etypes, importances
):
if importance.shape[0] == block.num_edges(etype): if importance.shape[0] == block.num_edges(etype):
block.edata["edge_weights"][etype] = importance block.edata["edge_weights"][etype] = importance
elif importances[0].shape[0] == block.num_edges(): elif importances[0].shape[0] == block.num_edges():
......
"""Data loading components for neighbor sampling""" """Data loading components for neighbor sampling"""
from ..base import NID, EID from ..base import EID, NID
from ..transforms import to_block from ..transforms import to_block
from .base import BlockSampler from .base import BlockSampler
class NeighborSampler(BlockSampler): class NeighborSampler(BlockSampler):
"""Sampler that builds computational dependency of node representations via """Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN. neighbor sampling for multilayer GNN.
...@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler): ...@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, fanouts, edge_dir='in', prob=None, mask=None, replace=False,
prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, def __init__(
output_device=None): self,
super().__init__(prefetch_node_feats=prefetch_node_feats, fanouts,
edge_dir="in",
prob=None,
mask=None,
replace=False,
prefetch_node_feats=None,
prefetch_labels=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__(
prefetch_node_feats=prefetch_node_feats,
prefetch_labels=prefetch_labels, prefetch_labels=prefetch_labels,
prefetch_edge_feats=prefetch_edge_feats, prefetch_edge_feats=prefetch_edge_feats,
output_device=output_device) output_device=output_device,
)
self.fanouts = fanouts self.fanouts = fanouts
self.edge_dir = edge_dir self.edge_dir = edge_dir
if mask is not None and prob is not None: if mask is not None and prob is not None:
raise ValueError( raise ValueError(
'Mask and probability arguments are mutually exclusive. ' "Mask and probability arguments are mutually exclusive. "
'Consider multiplying the probability with the mask ' "Consider multiplying the probability with the mask "
'to achieve the same goal.') "to achieve the same goal."
)
self.prob = prob or mask self.prob = prob or mask
self.replace = replace self.replace = replace
...@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler): ...@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
blocks = [] blocks = []
for fanout in reversed(self.fanouts): for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors( frontier = g.sample_neighbors(
seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob, seed_nodes,
replace=self.replace, output_device=self.output_device, fanout,
exclude_edges=exclude_eids) edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID] eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes) block = to_block(frontier, seed_nodes)
block.edata[EID] = eid block.edata[EID] = eid
...@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler): ...@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
return seed_nodes, output_nodes, blocks return seed_nodes, output_nodes, blocks
MultiLayerNeighborSampler = NeighborSampler MultiLayerNeighborSampler = NeighborSampler
class MultiLayerFullNeighborSampler(NeighborSampler): class MultiLayerFullNeighborSampler(NeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages """Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN. from all neighbors for multilayer GNN.
...@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler): ...@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """
def __init__(self, num_layers, **kwargs): def __init__(self, num_layers, **kwargs):
super().__init__([-1] * num_layers, **kwargs) super().__init__([-1] * num_layers, **kwargs)
"""ShaDow-GNN subgraph samplers.""" """ShaDow-GNN subgraph samplers."""
from ..sampling.utils import EidExcluder
from .. import transforms from .. import transforms
from ..base import NID from ..base import NID
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler from ..sampling.utils import EidExcluder
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
class ShaDowKHopSampler(Sampler): class ShaDowKHopSampler(Sampler):
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow """K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
...@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler): ...@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works >>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p') >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
""" """
def __init__(self, fanouts, replace=False, prob=None, prefetch_node_feats=None,
prefetch_edge_feats=None, output_device=None): def __init__(
self,
fanouts,
replace=False,
prob=None,
prefetch_node_feats=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__() super().__init__()
self.fanouts = fanouts self.fanouts = fanouts
self.replace = replace self.replace = replace
...@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler): ...@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
self.prefetch_edge_feats = prefetch_edge_feats self.prefetch_edge_feats = prefetch_edge_feats
self.output_device = output_device self.output_device = output_device
def sample(self, g, seed_nodes, exclude_eids=None): # pylint: disable=arguments-differ def sample(
self, g, seed_nodes, exclude_eids=None
): # pylint: disable=arguments-differ
"""Sampling function. """Sampling function.
Parameters Parameters
...@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler): ...@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
output_nodes = seed_nodes output_nodes = seed_nodes
for fanout in reversed(self.fanouts): for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors( frontier = g.sample_neighbors(
seed_nodes, fanout, output_device=self.output_device, seed_nodes,
replace=self.replace, prob=self.prob, exclude_edges=exclude_eids) fanout,
output_device=self.output_device,
replace=self.replace,
prob=self.prob,
exclude_edges=exclude_eids,
)
block = transforms.to_block(frontier, seed_nodes) block = transforms.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[NID] seed_nodes = block.srcdata[NID]
subg = g.subgraph(seed_nodes, relabel_nodes=True, output_device=self.output_device) subg = g.subgraph(
seed_nodes, relabel_nodes=True, output_device=self.output_device
)
if exclude_eids is not None: if exclude_eids is not None:
subg = EidExcluder(exclude_eids)(subg) subg = EidExcluder(exclude_eids)(subg)
......
...@@ -7,8 +7,7 @@ from itertools import product ...@@ -7,8 +7,7 @@ from itertools import product
from .base import BuiltinFunction, TargetCode from .base import BuiltinFunction, TargetCode
__all__ = ["copy_u", "copy_e", __all__ = ["copy_u", "copy_e", "BinaryMessageFunction", "CopyMessageFunction"]
"BinaryMessageFunction", "CopyMessageFunction"]
class MessageFunction(BuiltinFunction): class MessageFunction(BuiltinFunction):
...@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction): ...@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
-------- --------
u_mul_e u_mul_e
""" """
def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field): def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
self.binary_op = binary_op self.binary_op = binary_op
self.lhs = lhs self.lhs = lhs
...@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction): ...@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
-------- --------
copy_u copy_u
""" """
def __init__(self, target, in_field, out_field): def __init__(self, target, in_field, out_field):
self.target = target self.target = target
self.in_field = in_field self.in_field = in_field
...@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op): ...@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
-------- --------
>>> import dgl >>> import dgl
>>> message_func = dgl.function.{}('h', 'h', 'm') >>> message_func = dgl.function.{}('h', 'h', 'm')
""".format(binary_op, """.format(
binary_op,
TargetCode.CODE2STR[_TARGET_MAP[lhs]], TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]], TargetCode.CODE2STR[_TARGET_MAP[rhs]],
TargetCode.CODE2STR[_TARGET_MAP[lhs]], TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]], TargetCode.CODE2STR[_TARGET_MAP[rhs]],
name) name,
)
def func(lhs_field, rhs_field, out): def func(lhs_field, rhs_field, out):
return BinaryMessageFunction( return BinaryMessageFunction(
binary_op, _TARGET_MAP[lhs], binary_op,
_TARGET_MAP[rhs], lhs_field, rhs_field, out) _TARGET_MAP[lhs],
_TARGET_MAP[rhs],
lhs_field,
rhs_field,
out,
)
func.__name__ = name func.__name__ = name
func.__doc__ = docstring func.__doc__ = docstring
return func return func
...@@ -177,4 +186,5 @@ def _register_builtin_message_func(): ...@@ -177,4 +186,5 @@ def _register_builtin_message_func():
setattr(sys.modules[__name__], func.__name__, func) setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__) __all__.append(func.__name__)
_register_builtin_message_func() _register_builtin_message_func()
"""Module for various graph generator functions.""" """Module for various graph generator functions."""
from . import backend as F from . import backend as F, convert, random
from . import convert, random
__all__ = ["rand_graph", "rand_bipartite"] __all__ = ["rand_graph", "rand_bipartite"]
......
"""Python interfaces to DGL farthest point sampler.""" """Python interfaces to DGL farthest point sampler."""
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F, ndarray as nd
from .. import ndarray as nd
from .._ffi.base import DGLError from .._ffi.base import DGLError
from .._ffi.function import _init_api from .._ffi.function import _init_api
......
...@@ -5,11 +5,10 @@ import networkx as nx ...@@ -5,11 +5,10 @@ import networkx as nx
import numpy as np import numpy as np
import scipy import scipy
from . import backend as F from . import backend as F, utils
from . import utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object from ._ffi.object import ObjectBase, register_object
from .base import DGLError, dgl_warning from .base import dgl_warning, DGLError
class BoolFlag(object): class BoolFlag(object):
......
"""Classes for heterogeneous graphs.""" """Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from collections import defaultdict
from collections.abc import Mapping, Iterable
from contextlib import contextmanager
import copy import copy
import numbers
import itertools import itertools
import numbers
# pylint: disable= too-many-lines
from collections import defaultdict
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from . import backend as F, core, graph_index, heterograph_index, utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .ops import segment from .base import (
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning ALL,
from . import core dgl_warning,
from . import graph_index DGLError,
from . import heterograph_index EID,
from . import utils ETYPE,
from . import backend as F is_all,
NID,
NTYPE,
SLICE_FULL,
)
from .frame import Frame from .frame import Frame
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView from .ops import segment
from .view import (
HeteroEdgeDataView,
HeteroEdgeView,
HeteroNodeDataView,
HeteroNodeView,
)
__all__ = ["DGLGraph", "combine_names"]
__all__ = ['DGLGraph', 'combine_names']
class DGLGraph(object): class DGLGraph(object):
"""Class for storing graph structure and node/edge feature data. """Class for storing graph structure and node/edge feature data.
...@@ -35,16 +50,19 @@ class DGLGraph(object): ...@@ -35,16 +50,19 @@ class DGLGraph(object):
Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its
usage. usage.
""" """
is_block = False is_block = False
# pylint: disable=unused-argument, dangerous-default-value # pylint: disable=unused-argument, dangerous-default-value
def __init__(self, def __init__(
self,
gidx=[], gidx=[],
ntypes=['_N'], ntypes=["_N"],
etypes=['_E'], etypes=["_E"],
node_frames=None, node_frames=None,
edge_frames=None, edge_frames=None,
**deprecate_kwargs): **deprecate_kwargs
):
"""Internal constructor for creating a DGLGraph. """Internal constructor for creating a DGLGraph.
Parameters Parameters
...@@ -67,21 +85,42 @@ class DGLGraph(object): ...@@ -67,21 +85,42 @@ class DGLGraph(object):
of edge type i. (default: None) of edge type i. (default: None)
""" """
if isinstance(gidx, DGLGraph): if isinstance(gidx, DGLGraph):
raise DGLError('The input is already a DGLGraph. No need to create it again.') raise DGLError(
"The input is already a DGLGraph. No need to create it again."
)
if not isinstance(gidx, heterograph_index.HeteroGraphIndex): if not isinstance(gidx, heterograph_index.HeteroGraphIndex):
dgl_warning('Recommend creating graphs by `dgl.graph(data)`' dgl_warning(
' instead of `dgl.DGLGraph(data)`.') "Recommend creating graphs by `dgl.graph(data)`"
(sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(gidx) " instead of `dgl.DGLGraph(data)`."
if sparse_fmt == 'coo': )
(sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(
gidx
)
if sparse_fmt == "coo":
gidx = heterograph_index.create_unitgraph_from_coo( gidx = heterograph_index.create_unitgraph_from_coo(
1, num_src, num_dst, arrays[0], arrays[1], ['coo', 'csr', 'csc']) 1,
num_src,
num_dst,
arrays[0],
arrays[1],
["coo", "csr", "csc"],
)
else: else:
gidx = heterograph_index.create_unitgraph_from_csr( gidx = heterograph_index.create_unitgraph_from_csr(
1, num_src, num_dst, arrays[0], arrays[1], arrays[2], ['coo', 'csr', 'csc'], 1,
sparse_fmt == 'csc') num_src,
num_dst,
arrays[0],
arrays[1],
arrays[2],
["coo", "csr", "csc"],
sparse_fmt == "csc",
)
if len(deprecate_kwargs) != 0: if len(deprecate_kwargs) != 0:
dgl_warning('Keyword arguments {} are deprecated in v0.5, and can be safely' dgl_warning(
' removed in all cases.'.format(list(deprecate_kwargs.keys()))) "Keyword arguments {} are deprecated in v0.5, and can be safely"
" removed in all cases.".format(list(deprecate_kwargs.keys()))
)
self._init(gidx, ntypes, etypes, node_frames, edge_frames) self._init(gidx, ntypes, etypes, node_frames, edge_frames)
def _init(self, gidx, ntypes, etypes, node_frames, edge_frames): def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
...@@ -94,39 +133,51 @@ class DGLGraph(object): ...@@ -94,39 +133,51 @@ class DGLGraph(object):
# Handle node types # Handle node types
if isinstance(ntypes, tuple): if isinstance(ntypes, tuple):
if len(ntypes) != 2: if len(ntypes) != 2:
errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format( errmsg = "Invalid input. Expect a pair (srctypes, dsttypes) but got {}".format(
ntypes) ntypes
)
raise TypeError(errmsg) raise TypeError(errmsg)
if not self._graph.is_metagraph_unibipartite(): if not self._graph.is_metagraph_unibipartite():
raise ValueError('Invalid input. The metagraph must be a uni-directional' raise ValueError(
' bipartite graph.') "Invalid input. The metagraph must be a uni-directional"
" bipartite graph."
)
self._ntypes = ntypes[0] + ntypes[1] self._ntypes = ntypes[0] + ntypes[1]
self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])} self._srctypes_invmap = {t: i for i, t in enumerate(ntypes[0])}
self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])} self._dsttypes_invmap = {
t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1])
}
self._is_unibipartite = True self._is_unibipartite = True
if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1: if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])] self._canonical_etypes = [
(ntypes[0][0], etypes[0], ntypes[1][0])
]
else: else:
self._ntypes = ntypes self._ntypes = ntypes
if len(ntypes) == 1: if len(ntypes) == 1:
src_dst_map = None src_dst_map = None
else: else:
src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph) src_dst_map = find_src_dst_ntypes(
self._is_unibipartite = (src_dst_map is not None) self._ntypes, self._graph.metagraph
)
self._is_unibipartite = src_dst_map is not None
if self._is_unibipartite: if self._is_unibipartite:
self._srctypes_invmap, self._dsttypes_invmap = src_dst_map self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
else: else:
self._srctypes_invmap = {t : i for i, t in enumerate(self._ntypes)} self._srctypes_invmap = {
t: i for i, t in enumerate(self._ntypes)
}
self._dsttypes_invmap = self._srctypes_invmap self._dsttypes_invmap = self._srctypes_invmap
# Handle edge types # Handle edge types
self._etypes = etypes self._etypes = etypes
if self._canonical_etypes is None: if self._canonical_etypes is None:
if (len(etypes) == 1 and len(ntypes) == 1): if len(etypes) == 1 and len(ntypes) == 1:
self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])] self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])]
else: else:
self._canonical_etypes = make_canonical_etypes( self._canonical_etypes = make_canonical_etypes(
self._etypes, self._ntypes, self._graph.metagraph) self._etypes, self._ntypes, self._graph.metagraph
)
# An internal map from etype to canonical etype tuple. # An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate # If two etypes have the same name, an empty tuple is stored instead to indicate
...@@ -137,21 +188,29 @@ class DGLGraph(object): ...@@ -137,21 +188,29 @@ class DGLGraph(object):
self._etype2canonical[ety] = tuple() self._etype2canonical[ety] = tuple()
else: else:
self._etype2canonical[ety] = self._canonical_etypes[i] self._etype2canonical[ety] = self._canonical_etypes[i]
self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)} self._etypes_invmap = {
t: i for i, t in enumerate(self._canonical_etypes)
}
# node and edge frame # node and edge frame
if node_frames is None: if node_frames is None:
node_frames = [None] * len(self._ntypes) node_frames = [None] * len(self._ntypes)
node_frames = [Frame(num_rows=self._graph.number_of_nodes(i)) node_frames = [
if frame is None else frame Frame(num_rows=self._graph.number_of_nodes(i))
for i, frame in enumerate(node_frames)] if frame is None
else frame
for i, frame in enumerate(node_frames)
]
self._node_frames = node_frames self._node_frames = node_frames
if edge_frames is None: if edge_frames is None:
edge_frames = [None] * len(self._etypes) edge_frames = [None] * len(self._etypes)
edge_frames = [Frame(num_rows=self._graph.number_of_edges(i)) edge_frames = [
if frame is None else frame Frame(num_rows=self._graph.number_of_edges(i))
for i, frame in enumerate(edge_frames)] if frame is None
else frame
for i, frame in enumerate(edge_frames)
]
self._edge_frames = edge_frames self._edge_frames = edge_frames
def __setstate__(self, state): def __setstate__(self, state):
...@@ -162,40 +221,60 @@ class DGLGraph(object): ...@@ -162,40 +221,60 @@ class DGLGraph(object):
self.__dict__.update(state) self.__dict__.update(state)
elif isinstance(state, tuple) and len(state) == 5: elif isinstance(state, tuple) and len(state) == 5:
# DGL == 0.4.3 # DGL == 0.4.3
dgl_warning("The object is pickled with DGL == 0.4.3. " dgl_warning(
"Some of the original attributes are ignored.") "The object is pickled with DGL == 0.4.3. "
"Some of the original attributes are ignored."
)
self._init(*state) self._init(*state)
elif isinstance(state, dict): elif isinstance(state, dict):
# DGL <= 0.4.2 # DGL <= 0.4.2
dgl_warning("The object is pickled with DGL <= 0.4.2. " dgl_warning(
"Some of the original attributes are ignored.") "The object is pickled with DGL <= 0.4.2. "
self._init(state['_graph'], state['_ntypes'], state['_etypes'], state['_node_frames'], "Some of the original attributes are ignored."
state['_edge_frames']) )
self._init(
state["_graph"],
state["_ntypes"],
state["_etypes"],
state["_node_frames"],
state["_edge_frames"],
)
else: else:
raise IOError("Unrecognized pickle format.") raise IOError("Unrecognized pickle format.")
def __repr__(self): def __repr__(self):
if len(self.ntypes) == 1 and len(self.etypes) == 1: if len(self.ntypes) == 1 and len(self.etypes) == 1:
ret = ('Graph(num_nodes={node}, num_edges={edge},\n' ret = (
' ndata_schemes={ndata}\n' "Graph(num_nodes={node}, num_edges={edge},\n"
' edata_schemes={edata})') " ndata_schemes={ndata}\n"
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(), " edata_schemes={edata})"
)
return ret.format(
node=self.number_of_nodes(),
edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()), ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes())) edata=str(self.edge_attr_schemes()),
)
else: else:
ret = ('Graph(num_nodes={node},\n' ret = (
' num_edges={edge},\n' "Graph(num_nodes={node},\n"
' metagraph={meta})') " num_edges={edge},\n"
nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i) " metagraph={meta})"
for i in range(len(self.ntypes))} )
nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i) nnode_dict = {
for i in range(len(self.etypes))} self.ntypes[i]: self._graph.number_of_nodes(i)
for i in range(len(self.ntypes))
}
nedge_dict = {
self.canonical_etypes[i]: self._graph.number_of_edges(i)
for i in range(len(self.etypes))
}
meta = str(self.metagraph().edges(keys=True)) meta = str(self.metagraph().edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta) return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
def __copy__(self): def __copy__(self):
"""Shallow copy implementation.""" """Shallow copy implementation."""
#TODO(minjie): too many states in python; should clean up and lower to C # TODO(minjie): too many states in python; should clean up and lower to C
cls = type(self) cls = type(self)
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.__dict__.update(self.__dict__) obj.__dict__.update(self.__dict__)
...@@ -298,14 +377,16 @@ class DGLGraph(object): ...@@ -298,14 +377,16 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support add_nodes # TODO(xiangsx): block do not support add_nodes
if ntype is None: if ntype is None:
if self._graph.number_of_ntypes() != 1: if self._graph.number_of_ntypes() != 1:
raise DGLError('Node type name must be specified if there are more than one ' raise DGLError(
'node types.') "Node type name must be specified if there are more than one "
"node types."
)
# nothing happen # nothing happen
if num == 0: if num == 0:
return return
assert num > 0, 'Number of new nodes should be larger than one.' assert num > 0, "Number of new nodes should be larger than one."
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
# update graph idx # update graph idx
metagraph = self._graph.metagraph metagraph = self._graph.metagraph
...@@ -319,23 +400,32 @@ class DGLGraph(object): ...@@ -319,23 +400,32 @@ class DGLGraph(object):
relation_graphs = [] relation_graphs = []
for c_etype in self.canonical_etypes: for c_etype in self.canonical_etypes:
# src or dst == ntype, update the relation graph # src or dst == ntype, update the relation graph
if self.get_ntype_id(c_etype[0]) == ntid or self.get_ntype_id(c_etype[2]) == ntid: if (
u, v = self.edges(form='uv', order='eid', etype=c_etype) self.get_ntype_id(c_etype[0]) == ntid
or self.get_ntype_id(c_etype[2]) == ntid
):
u, v = self.edges(form="uv", order="eid", etype=c_etype)
hgidx = heterograph_index.create_unitgraph_from_coo( hgidx = heterograph_index.create_unitgraph_from_coo(
1 if c_etype[0] == c_etype[2] else 2, 1 if c_etype[0] == c_etype[2] else 2,
self.number_of_nodes(c_etype[0]) + \ self.number_of_nodes(c_etype[0])
(num if self.get_ntype_id(c_etype[0]) == ntid else 0), + (num if self.get_ntype_id(c_etype[0]) == ntid else 0),
self.number_of_nodes(c_etype[2]) + \ self.number_of_nodes(c_etype[2])
(num if self.get_ntype_id(c_etype[2]) == ntid else 0), + (num if self.get_ntype_id(c_etype[2]) == ntid else 0),
u, u,
v, v,
['coo', 'csr', 'csc']) ["coo", "csr", "csc"],
)
relation_graphs.append(hgidx) relation_graphs.append(hgidx)
else: else:
# do nothing # do nothing
relation_graphs.append(self._graph.get_relation_graph(self.get_etype_id(c_etype))) relation_graphs.append(
self._graph.get_relation_graph(self.get_etype_id(c_etype))
)
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64")) metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
self._graph = hgidx self._graph = hgidx
# update data frames # update data frames
...@@ -452,26 +542,33 @@ class DGLGraph(object): ...@@ -452,26 +542,33 @@ class DGLGraph(object):
remove_edges remove_edges
""" """
# TODO(xiangsx): block do not support add_edges # TODO(xiangsx): block do not support add_edges
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, "v")
if etype is None: if etype is None:
if self._graph.number_of_etypes() != 1: if self._graph.number_of_etypes() != 1:
raise DGLError('Edge type name must be specified if there are more than one ' raise DGLError(
'edge types.') "Edge type name must be specified if there are more than one "
"edge types."
)
# nothing changed # nothing changed
if len(u) == 0 or len(v) == 0: if len(u) == 0 or len(v) == 0:
return return
assert len(u) == len(v) or len(u) == 1 or len(v) == 1, \ assert len(u) == len(v) or len(u) == 1 or len(v) == 1, (
'The number of source nodes and the number of destination nodes should be same, ' \ "The number of source nodes and the number of destination nodes should be same, "
'or either the number of source nodes or the number of destination nodes is 1.' "or either the number of source nodes or the number of destination nodes is 1."
)
if len(u) == 1 and len(v) > 1: if len(u) == 1 and len(v) > 1:
u = F.full_1d(len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u)) u = F.full_1d(
len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u)
)
if len(v) == 1 and len(u) > 1: if len(v) == 1 and len(u) > 1:
v = F.full_1d(len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v)) v = F.full_1d(
len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v)
)
u_type, e_type, v_type = self.to_canonical_etype(etype) u_type, e_type, v_type = self.to_canonical_etype(etype)
# if end nodes of adding edges does not exists # if end nodes of adding edges does not exists
...@@ -501,22 +598,28 @@ class DGLGraph(object): ...@@ -501,22 +598,28 @@ class DGLGraph(object):
for c_etype in self.canonical_etypes: for c_etype in self.canonical_etypes:
# the target edge type # the target edge type
if c_etype == (u_type, e_type, v_type): if c_etype == (u_type, e_type, v_type):
old_u, old_v = self.edges(form='uv', order='eid', etype=c_etype) old_u, old_v = self.edges(form="uv", order="eid", etype=c_etype)
hgidx = heterograph_index.create_unitgraph_from_coo( hgidx = heterograph_index.create_unitgraph_from_coo(
1 if u_type == v_type else 2, 1 if u_type == v_type else 2,
self.number_of_nodes(u_type), self.number_of_nodes(u_type),
self.number_of_nodes(v_type), self.number_of_nodes(v_type),
F.cat([old_u, u], dim=0), F.cat([old_u, u], dim=0),
F.cat([old_v, v], dim=0), F.cat([old_v, v], dim=0),
['coo', 'csr', 'csc']) ["coo", "csr", "csc"],
)
relation_graphs.append(hgidx) relation_graphs.append(hgidx)
else: else:
# do nothing # do nothing
# Note: node range change has been handled in add_nodes() # Note: node range change has been handled in add_nodes()
relation_graphs.append(self._graph.get_relation_graph(self.get_etype_id(c_etype))) relation_graphs.append(
self._graph.get_relation_graph(self.get_etype_id(c_etype))
)
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64")) metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
self._graph = hgidx self._graph = hgidx
# handle data # handle data
...@@ -607,15 +710,19 @@ class DGLGraph(object): ...@@ -607,15 +710,19 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_edges # TODO(xiangsx): block do not support remove_edges
if etype is None: if etype is None:
if self._graph.number_of_etypes() != 1: if self._graph.number_of_etypes() != 1:
raise DGLError('Edge type name must be specified if there are more than one ' \ raise DGLError(
'edge types.') "Edge type name must be specified if there are more than one "
eids = utils.prepare_tensor(self, eids, 'u') "edge types."
)
eids = utils.prepare_tensor(self, eids, "u")
if len(eids) == 0: if len(eids) == 0:
# no edge to delete # no edge to delete
return return
assert self.number_of_edges(etype) > F.as_scalar(F.max(eids, dim=0)), \ assert self.number_of_edges(etype) > F.as_scalar(
'The input eid {} is out of the range [0:{})'.format( F.max(eids, dim=0)
F.as_scalar(F.max(eids, dim=0)), self.number_of_edges(etype)) ), "The input eid {} is out of the range [0:{})".format(
F.as_scalar(F.max(eids, dim=0)), self.number_of_edges(etype)
)
# edge_subgraph # edge_subgraph
edges = {} edges = {}
...@@ -623,25 +730,36 @@ class DGLGraph(object): ...@@ -623,25 +730,36 @@ class DGLGraph(object):
for c_etype in self.canonical_etypes: for c_etype in self.canonical_etypes:
# the target edge type # the target edge type
if c_etype == (u_type, e_type, v_type): if c_etype == (u_type, e_type, v_type):
origin_eids = self.edges(form='eid', order='eid', etype=c_etype) origin_eids = self.edges(form="eid", order="eid", etype=c_etype)
edges[c_etype] = utils.compensate(eids, origin_eids) edges[c_etype] = utils.compensate(eids, origin_eids)
else: else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype) edges[c_etype] = self.edges(
form="eid", order="eid", etype=c_etype
)
# If the graph is batched, update batch_num_edges # If the graph is batched, update batch_num_edges
batched = self._batch_num_edges is not None batched = self._batch_num_edges is not None
if batched: if batched:
c_etype = (u_type, e_type, v_type) c_etype = (u_type, e_type, v_type)
one_hot_removed_edges = F.zeros((self.num_edges(c_etype),), F.float32, self.device) one_hot_removed_edges = F.zeros(
one_hot_removed_edges = F.scatter_row(one_hot_removed_edges, eids, (self.num_edges(c_etype),), F.float32, self.device
F.full_1d(len(eids), 1., F.float32, self.device)) )
one_hot_removed_edges = F.scatter_row(
one_hot_removed_edges,
eids,
F.full_1d(len(eids), 1.0, F.float32, self.device),
)
c_etype_batch_num_edges = self._batch_num_edges[c_etype] c_etype_batch_num_edges = self._batch_num_edges[c_etype]
batch_num_removed_edges = segment.segment_reduce(c_etype_batch_num_edges, batch_num_removed_edges = segment.segment_reduce(
one_hot_removed_edges, reducer='sum') c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum"
self._batch_num_edges[c_etype] = c_etype_batch_num_edges - \ )
F.astype(batch_num_removed_edges, F.int64) self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(
batch_num_removed_edges, F.int64
sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids) )
sub_g = self.edge_subgraph(
edges, relabel_nodes=False, store_ids=store_ids
)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames self._edge_frames = sub_g._edge_frames
...@@ -733,16 +851,20 @@ class DGLGraph(object): ...@@ -733,16 +851,20 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_nodes # TODO(xiangsx): block do not support remove_nodes
if ntype is None: if ntype is None:
if self._graph.number_of_ntypes() != 1: if self._graph.number_of_ntypes() != 1:
raise DGLError('Node type name must be specified if there are more than one ' \ raise DGLError(
'node types.') "Node type name must be specified if there are more than one "
"node types."
)
nids = utils.prepare_tensor(self, nids, 'u') nids = utils.prepare_tensor(self, nids, "u")
if len(nids) == 0: if len(nids) == 0:
# no node to delete # no node to delete
return return
assert self.number_of_nodes(ntype) > F.as_scalar(F.max(nids, dim=0)), \ assert self.number_of_nodes(ntype) > F.as_scalar(
'The input nids {} is out of the range [0:{})'.format( F.max(nids, dim=0)
F.as_scalar(F.max(nids, dim=0)), self.number_of_nodes(ntype)) ), "The input nids {} is out of the range [0:{})".format(
F.as_scalar(F.max(nids, dim=0)), self.number_of_nodes(ntype)
)
ntid = self.get_ntype_id(ntype) ntid = self.get_ntype_id(ntype)
nodes = {} nodes = {}
...@@ -757,18 +879,28 @@ class DGLGraph(object): ...@@ -757,18 +879,28 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_nodes # If the graph is batched, update batch_num_nodes
batched = self._batch_num_nodes is not None batched = self._batch_num_nodes is not None
if batched: if batched:
one_hot_removed_nodes = F.zeros((self.num_nodes(target_ntype),), one_hot_removed_nodes = F.zeros(
F.float32, self.device) (self.num_nodes(target_ntype),), F.float32, self.device
one_hot_removed_nodes = F.scatter_row(one_hot_removed_nodes, nids, )
F.full_1d(len(nids), 1., F.float32, self.device)) one_hot_removed_nodes = F.scatter_row(
one_hot_removed_nodes,
nids,
F.full_1d(len(nids), 1.0, F.float32, self.device),
)
c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype] c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype]
batch_num_removed_nodes = segment.segment_reduce( batch_num_removed_nodes = segment.segment_reduce(
c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer='sum') c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer="sum"
self._batch_num_nodes[target_ntype] = c_ntype_batch_num_nodes - \ )
F.astype(batch_num_removed_nodes, F.int64) self._batch_num_nodes[
target_ntype
] = c_ntype_batch_num_nodes - F.astype(
batch_num_removed_nodes, F.int64
)
# Record old num_edges to check later whether some edges were removed # Record old num_edges to check later whether some edges were removed
old_num_edges = {c_etype: self._graph.number_of_edges(self.get_etype_id(c_etype)) old_num_edges = {
for c_etype in self.canonical_etypes} c_etype: self._graph.number_of_edges(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes
}
# node_subgraph # node_subgraph
# If batch_num_edges is to be updated, record the original edge IDs # If batch_num_edges is to be updated, record the original edge IDs
...@@ -780,22 +912,36 @@ class DGLGraph(object): ...@@ -780,22 +912,36 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_edges # If the graph is batched, update batch_num_edges
if batched: if batched:
canonical_etypes = [ canonical_etypes = [
c_etype for c_etype in self.canonical_etypes if c_etype
self._graph.number_of_edges(self.get_etype_id(c_etype)) != old_num_edges[c_etype]] for c_etype in self.canonical_etypes
if self._graph.number_of_edges(self.get_etype_id(c_etype))
!= old_num_edges[c_etype]
]
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if self._graph.number_of_edges(self.get_etype_id(c_etype)) == 0: if self._graph.number_of_edges(self.get_etype_id(c_etype)) == 0:
self._batch_num_edges[c_etype] = F.zeros( self._batch_num_edges[c_etype] = F.zeros(
(self.batch_size,), F.int64, self.device) (self.batch_size,), F.int64, self.device
)
continue continue
one_hot_left_edges = F.zeros((old_num_edges[c_etype],), F.float32, self.device) one_hot_left_edges = F.zeros(
(old_num_edges[c_etype],), F.float32, self.device
)
eids = self.edges[c_etype].data[EID] eids = self.edges[c_etype].data[EID]
one_hot_left_edges = F.scatter_row(one_hot_left_edges, eids, one_hot_left_edges = F.scatter_row(
F.full_1d(len(eids), 1., F.float32, self.device)) one_hot_left_edges,
eids,
F.full_1d(len(eids), 1.0, F.float32, self.device),
)
batch_num_left_edges = segment.segment_reduce( batch_num_left_edges = segment.segment_reduce(
self._batch_num_edges[c_etype], one_hot_left_edges, reducer='sum') self._batch_num_edges[c_etype],
self._batch_num_edges[c_etype] = F.astype(batch_num_left_edges, F.int64) one_hot_left_edges,
reducer="sum",
)
self._batch_num_edges[c_etype] = F.astype(
batch_num_left_edges, F.int64
)
if batched and not store_ids: if batched and not store_ids:
for c_ntype in self.ntypes: for c_ntype in self.ntypes:
...@@ -810,7 +956,6 @@ class DGLGraph(object): ...@@ -810,7 +956,6 @@ class DGLGraph(object):
self._batch_num_nodes = None self._batch_num_nodes = None
self._batch_num_edges = None self._batch_num_edges = None
################################################################# #################################################################
# Metagraph query # Metagraph query
################################################################# #################################################################
...@@ -1080,7 +1225,9 @@ class DGLGraph(object): ...@@ -1080,7 +1225,9 @@ class DGLGraph(object):
nx_graph = self._graph.metagraph.to_networkx() nx_graph = self._graph.metagraph.to_networkx()
nx_metagraph = nx.MultiDiGraph() nx_metagraph = nx.MultiDiGraph()
for u_v in nx_graph.edges: for u_v in nx_graph.edges:
srctype, etype, dsttype = self.canonical_etypes[nx_graph.edges[u_v]['id']] srctype, etype, dsttype = self.canonical_etypes[
nx_graph.edges[u_v]["id"]
]
nx_metagraph.add_edge(srctype, dsttype, etype) nx_metagraph.add_edge(srctype, dsttype, etype)
return nx_metagraph return nx_metagraph
...@@ -1133,8 +1280,10 @@ class DGLGraph(object): ...@@ -1133,8 +1280,10 @@ class DGLGraph(object):
""" """
if etype is None: if etype is None:
if len(self.etypes) != 1: if len(self.etypes) != 1:
raise DGLError('Edge type name must be specified if there are more than one ' raise DGLError(
'edge types.') "Edge type name must be specified if there are more than one "
"edge types."
)
etype = self.etypes[0] etype = self.etypes[0]
if isinstance(etype, tuple): if isinstance(etype, tuple):
return etype return etype
...@@ -1143,8 +1292,10 @@ class DGLGraph(object): ...@@ -1143,8 +1292,10 @@ class DGLGraph(object):
if ret is None: if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype)) raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0: if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical edge type ' raise DGLError(
'in the form of (srctype, etype, dsttype)' % etype) 'Edge type "%s" is ambiguous. Please use canonical edge type '
"in the form of (srctype, etype, dsttype)" % etype
)
return ret return ret
def get_ntype_id(self, ntype): def get_ntype_id(self, ntype):
...@@ -1164,19 +1315,23 @@ class DGLGraph(object): ...@@ -1164,19 +1315,23 @@ class DGLGraph(object):
""" """
if self.is_unibipartite and ntype is not None: if self.is_unibipartite and ntype is not None:
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True. # Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
if ntype.startswith('SRC/'): if ntype.startswith("SRC/"):
return self.get_ntype_id_from_src(ntype[4:]) return self.get_ntype_id_from_src(ntype[4:])
elif ntype.startswith('DST/'): elif ntype.startswith("DST/"):
return self.get_ntype_id_from_dst(ntype[4:]) return self.get_ntype_id_from_dst(ntype[4:])
# If there is no prefix, fallback to normal lookup. # If there is no prefix, fallback to normal lookup.
# Lookup both SRC and DST # Lookup both SRC and DST
if ntype is None: if ntype is None:
if self.is_unibipartite or len(self._srctypes_invmap) != 1: if self.is_unibipartite or len(self._srctypes_invmap) != 1:
raise DGLError('Node type name must be specified if there are more than one ' raise DGLError(
'node types.') "Node type name must be specified if there are more than one "
"node types."
)
return 0 return 0
ntid = self._srctypes_invmap.get(ntype, self._dsttypes_invmap.get(ntype, None)) ntid = self._srctypes_invmap.get(
ntype, self._dsttypes_invmap.get(ntype, None)
)
if ntid is None: if ntid is None:
raise DGLError('Node type "{}" does not exist.'.format(ntype)) raise DGLError('Node type "{}" does not exist.'.format(ntype))
return ntid return ntid
...@@ -1198,8 +1353,10 @@ class DGLGraph(object): ...@@ -1198,8 +1353,10 @@ class DGLGraph(object):
""" """
if ntype is None: if ntype is None:
if len(self._srctypes_invmap) != 1: if len(self._srctypes_invmap) != 1:
raise DGLError('SRC node type name must be specified if there are more than one ' raise DGLError(
'SRC node types.') "SRC node type name must be specified if there are more than one "
"SRC node types."
)
return next(iter(self._srctypes_invmap.values())) return next(iter(self._srctypes_invmap.values()))
ntid = self._srctypes_invmap.get(ntype, None) ntid = self._srctypes_invmap.get(ntype, None)
if ntid is None: if ntid is None:
...@@ -1223,8 +1380,10 @@ class DGLGraph(object): ...@@ -1223,8 +1380,10 @@ class DGLGraph(object):
""" """
if ntype is None: if ntype is None:
if len(self._dsttypes_invmap) != 1: if len(self._dsttypes_invmap) != 1:
raise DGLError('DST node type name must be specified if there are more than one ' raise DGLError(
'DST node types.') "DST node type name must be specified if there are more than one "
"DST node types."
)
return next(iter(self._dsttypes_invmap.values())) return next(iter(self._dsttypes_invmap.values()))
ntid = self._dsttypes_invmap.get(ntype, None) ntid = self._dsttypes_invmap.get(ntype, None)
if ntid is None: if ntid is None:
...@@ -1248,8 +1407,10 @@ class DGLGraph(object): ...@@ -1248,8 +1407,10 @@ class DGLGraph(object):
""" """
if etype is None: if etype is None:
if self._graph.number_of_etypes() != 1: if self._graph.number_of_etypes() != 1:
raise DGLError('Edge type name must be specified if there are more than one ' raise DGLError(
'edge types.') "Edge type name must be specified if there are more than one "
"edge types."
)
return 0 return 0
etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None) etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
if etid is None: if etid is None:
...@@ -1346,17 +1507,23 @@ class DGLGraph(object): ...@@ -1346,17 +1507,23 @@ class DGLGraph(object):
tensor([2, 1]) tensor([2, 1])
""" """
if ntype is not None and ntype not in self.ntypes: if ntype is not None and ntype not in self.ntypes:
raise DGLError('Expect ntype in {}, got {}'.format(self.ntypes, ntype)) raise DGLError(
"Expect ntype in {}, got {}".format(self.ntypes, ntype)
)
if self._batch_num_nodes is None: if self._batch_num_nodes is None:
self._batch_num_nodes = {} self._batch_num_nodes = {}
for ty in self.ntypes: for ty in self.ntypes:
bnn = F.copy_to(F.tensor([self.number_of_nodes(ty)], F.int64), self.device) bnn = F.copy_to(
F.tensor([self.number_of_nodes(ty)], F.int64), self.device
)
self._batch_num_nodes[ty] = bnn self._batch_num_nodes[ty] = bnn
if ntype is None: if ntype is None:
if len(self.ntypes) != 1: if len(self.ntypes) != 1:
raise DGLError('Node type name must be specified if there are more than one ' raise DGLError(
'node types.') "Node type name must be specified if there are more than one "
"node types."
)
ntype = self.ntypes[0] ntype = self.ntypes[0]
return self._batch_num_nodes[ntype] return self._batch_num_nodes[ntype]
...@@ -1440,8 +1607,10 @@ class DGLGraph(object): ...@@ -1440,8 +1607,10 @@ class DGLGraph(object):
""" """
if not isinstance(val, Mapping): if not isinstance(val, Mapping):
if len(self.ntypes) != 1: if len(self.ntypes) != 1:
raise DGLError('Must provide a dictionary when there are multiple node types.') raise DGLError(
val = {self.ntypes[0] : val} "Must provide a dictionary when there are multiple node types."
)
val = {self.ntypes[0]: val}
self._batch_num_nodes = val self._batch_num_nodes = val
def batch_num_edges(self, etype=None): def batch_num_edges(self, etype=None):
...@@ -1494,12 +1663,16 @@ class DGLGraph(object): ...@@ -1494,12 +1663,16 @@ class DGLGraph(object):
if self._batch_num_edges is None: if self._batch_num_edges is None:
self._batch_num_edges = {} self._batch_num_edges = {}
for ty in self.canonical_etypes: for ty in self.canonical_etypes:
bne = F.copy_to(F.tensor([self.number_of_edges(ty)], F.int64), self.device) bne = F.copy_to(
F.tensor([self.number_of_edges(ty)], F.int64), self.device
)
self._batch_num_edges[ty] = bne self._batch_num_edges[ty] = bne
if etype is None: if etype is None:
if len(self.etypes) != 1: if len(self.etypes) != 1:
raise DGLError('Edge type name must be specified if there are more than one ' raise DGLError(
'edge types.') "Edge type name must be specified if there are more than one "
"edge types."
)
etype = self.canonical_etypes[0] etype = self.canonical_etypes[0]
else: else:
etype = self.to_canonical_etype(etype) etype = self.to_canonical_etype(etype)
...@@ -1585,8 +1758,10 @@ class DGLGraph(object): ...@@ -1585,8 +1758,10 @@ class DGLGraph(object):
""" """
if not isinstance(val, Mapping): if not isinstance(val, Mapping):
if len(self.etypes) != 1: if len(self.etypes) != 1:
raise DGLError('Must provide a dictionary when there are multiple edge types.') raise DGLError(
val = {self.canonical_etypes[0] : val} "Must provide a dictionary when there are multiple edge types."
)
val = {self.canonical_etypes[0]: val}
self._batch_num_edges = val self._batch_num_edges = val
################################################################# #################################################################
...@@ -2130,10 +2305,14 @@ class DGLGraph(object): ...@@ -2130,10 +2305,14 @@ class DGLGraph(object):
def _find_etypes(self, key): def _find_etypes(self, key):
etypes = [ etypes = [
i for i, (srctype, etype, dsttype) in enumerate(self._canonical_etypes) if i
(key[0] == SLICE_FULL or key[0] == srctype) and for i, (srctype, etype, dsttype) in enumerate(
(key[1] == SLICE_FULL or key[1] == etype) and self._canonical_etypes
(key[2] == SLICE_FULL or key[2] == dsttype)] )
if (key[0] == SLICE_FULL or key[0] == srctype)
and (key[1] == SLICE_FULL or key[1] == etype)
and (key[2] == SLICE_FULL or key[2] == dsttype)
]
return etypes return etypes
def __getitem__(self, key): def __getitem__(self, key):
...@@ -2215,9 +2394,11 @@ class DGLGraph(object): ...@@ -2215,9 +2394,11 @@ class DGLGraph(object):
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE] >>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
""" """
err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\ err_msg = (
"to get view of one relation type. Use : to slice multiple types (e.g. " +\ "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
"G['srctype', :, 'dsttype'])." + "to get view of one relation type. Use : to slice multiple types (e.g. "
+ "G['srctype', :, 'dsttype'])."
)
orig_key = key orig_key = key
if not isinstance(key, tuple): if not isinstance(key, tuple):
...@@ -2229,7 +2410,11 @@ class DGLGraph(object): ...@@ -2229,7 +2410,11 @@ class DGLGraph(object):
etypes = self._find_etypes(key) etypes = self._find_etypes(key)
if len(etypes) == 0: if len(etypes) == 0:
raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key)) raise DGLError(
'Invalid key "{}". Must be one of the edge types.'.format(
orig_key
)
)
if len(etypes) == 1: if len(etypes) == 1:
# no ambiguity: return the unitgraph itself # no ambiguity: return the unitgraph itself
...@@ -2248,7 +2433,9 @@ class DGLGraph(object): ...@@ -2248,7 +2433,9 @@ class DGLGraph(object):
new_etypes = [etype] new_etypes = [etype]
new_eframes = [self._edge_frames[etid]] new_eframes = [self._edge_frames[etid]]
return self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes) return self.__class__(
new_g, new_ntypes, new_etypes, new_nframes, new_eframes
)
else: else:
flat = self._graph.flatten_relations(etypes) flat = self._graph.flatten_relations(etypes)
new_g = flat.graph new_g = flat.graph
...@@ -2262,7 +2449,8 @@ class DGLGraph(object): ...@@ -2262,7 +2449,8 @@ class DGLGraph(object):
new_ntypes.append(combine_names(self.ntypes, dtids)) new_ntypes.append(combine_names(self.ntypes, dtids))
new_nframes = [ new_nframes = [
combine_frames(self._node_frames, stids), combine_frames(self._node_frames, stids),
combine_frames(self._node_frames, dtids)] combine_frames(self._node_frames, dtids),
]
else: else:
assert np.array_equal(stids, dtids) assert np.array_equal(stids, dtids)
new_nframes = [combine_frames(self._node_frames, stids)] new_nframes = [combine_frames(self._node_frames, stids)]
...@@ -2270,16 +2458,28 @@ class DGLGraph(object): ...@@ -2270,16 +2458,28 @@ class DGLGraph(object):
new_eframes = [combine_frames(self._edge_frames, etids)] new_eframes = [combine_frames(self._edge_frames, etids)]
# create new heterograph # create new heterograph
new_hg = self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes) new_hg = self.__class__(
new_g, new_ntypes, new_etypes, new_nframes, new_eframes
)
src = new_ntypes[0] src = new_ntypes[0]
dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
# put the parent node/edge type and IDs # put the parent node/edge type and IDs
new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype) new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid) flat.induced_srctype
new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype) )
new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid) new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(
new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype) flat.induced_srcid
)
new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
flat.induced_dsttype
)
new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(
flat.induced_dstid
)
new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(
flat.induced_etype
)
new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid) new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)
return new_hg return new_hg
...@@ -2331,7 +2531,12 @@ class DGLGraph(object): ...@@ -2331,7 +2531,12 @@ class DGLGraph(object):
12 12
""" """
if ntype is None: if ntype is None:
return sum([self._graph.number_of_nodes(ntid) for ntid in range(len(self.ntypes))]) return sum(
[
self._graph.number_of_nodes(ntid)
for ntid in range(len(self.ntypes))
]
)
else: else:
return self._graph.number_of_nodes(self.get_ntype_id(ntype)) return self._graph.number_of_nodes(self.get_ntype_id(ntype))
...@@ -2396,10 +2601,16 @@ class DGLGraph(object): ...@@ -2396,10 +2601,16 @@ class DGLGraph(object):
7 7
""" """
if ntype is None: if ntype is None:
return sum([self._graph.number_of_nodes(self.get_ntype_id_from_src(nty)) return sum(
for nty in self.srctypes]) [
self._graph.number_of_nodes(self.get_ntype_id_from_src(nty))
for nty in self.srctypes
]
)
else: else:
return self._graph.number_of_nodes(self.get_ntype_id_from_src(ntype)) return self._graph.number_of_nodes(
self.get_ntype_id_from_src(ntype)
)
def number_of_dst_nodes(self, ntype=None): def number_of_dst_nodes(self, ntype=None):
"""Alias of :func:`num_dst_nodes`""" """Alias of :func:`num_dst_nodes`"""
...@@ -2462,10 +2673,16 @@ class DGLGraph(object): ...@@ -2462,10 +2673,16 @@ class DGLGraph(object):
12 12
""" """
if ntype is None: if ntype is None:
return sum([self._graph.number_of_nodes(self.get_ntype_id_from_dst(nty)) return sum(
for nty in self.dsttypes]) [
self._graph.number_of_nodes(self.get_ntype_id_from_dst(nty))
for nty in self.dsttypes
]
)
else: else:
return self._graph.number_of_nodes(self.get_ntype_id_from_dst(ntype)) return self._graph.number_of_nodes(
self.get_ntype_id_from_dst(ntype)
)
def number_of_edges(self, etype=None): def number_of_edges(self, etype=None):
"""Alias of :func:`num_edges`""" """Alias of :func:`num_edges`"""
...@@ -2522,8 +2739,12 @@ class DGLGraph(object): ...@@ -2522,8 +2739,12 @@ class DGLGraph(object):
3 3
""" """
if etype is None: if etype is None:
return sum([self._graph.number_of_edges(etid) return sum(
for etid in range(len(self.canonical_etypes))]) [
self._graph.number_of_edges(etid)
for etid in range(len(self.canonical_etypes))
]
)
else: else:
return self._graph.number_of_edges(self.get_etype_id(etype)) return self._graph.number_of_edges(self.get_etype_id(etype))
...@@ -2708,10 +2929,11 @@ class DGLGraph(object): ...@@ -2708,10 +2929,11 @@ class DGLGraph(object):
tensor([False, True, True]) tensor([False, True, True])
""" """
vid_tensor = utils.prepare_tensor(self, vid, "vid") vid_tensor = utils.prepare_tensor(self, vid, "vid")
if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(vid_tensor): if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(
raise DGLError('All IDs must be non-negative integers.') vid_tensor
ret = self._graph.has_nodes( ):
self.get_ntype_id(ntype), vid_tensor) raise DGLError("All IDs must be non-negative integers.")
ret = self._graph.has_nodes(self.get_ntype_id(ntype), vid_tensor)
if isinstance(vid, numbers.Integral): if isinstance(vid, numbers.Integral):
return bool(F.as_scalar(ret)) return bool(F.as_scalar(ret))
else: else:
...@@ -2793,15 +3015,19 @@ class DGLGraph(object): ...@@ -2793,15 +3015,19 @@ class DGLGraph(object):
tensor([True, True]) tensor([True, True])
""" """
srctype, _, dsttype = self.to_canonical_etype(etype) srctype, _, dsttype = self.to_canonical_etype(etype)
u_tensor = utils.prepare_tensor(self, u, 'u') u_tensor = utils.prepare_tensor(self, u, "u")
if F.as_scalar(F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)) != len(u_tensor): if F.as_scalar(
raise DGLError('u contains invalid node IDs') F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
v_tensor = utils.prepare_tensor(self, v, 'v') ) != len(u_tensor):
if F.as_scalar(F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)) != len(v_tensor): raise DGLError("u contains invalid node IDs")
raise DGLError('v contains invalid node IDs') v_tensor = utils.prepare_tensor(self, v, "v")
if F.as_scalar(
F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)
) != len(v_tensor):
raise DGLError("v contains invalid node IDs")
ret = self._graph.has_edges_between( ret = self._graph.has_edges_between(
self.get_etype_id(etype), self.get_etype_id(etype), u_tensor, v_tensor
u_tensor, v_tensor) )
if isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral): if isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral):
return bool(F.as_scalar(ret)) return bool(F.as_scalar(ret))
else: else:
...@@ -2863,7 +3089,7 @@ class DGLGraph(object): ...@@ -2863,7 +3089,7 @@ class DGLGraph(object):
successors successors
""" """
if not self.has_nodes(v, self.to_canonical_etype(etype)[-1]): if not self.has_nodes(v, self.to_canonical_etype(etype)[-1]):
raise DGLError('Non-existing node ID {}'.format(v)) raise DGLError("Non-existing node ID {}".format(v))
return self._graph.predecessors(self.get_etype_id(etype), v) return self._graph.predecessors(self.get_etype_id(etype), v)
def successors(self, v, etype=None): def successors(self, v, etype=None):
...@@ -2921,7 +3147,7 @@ class DGLGraph(object): ...@@ -2921,7 +3147,7 @@ class DGLGraph(object):
predecessors predecessors
""" """
if not self.has_nodes(v, self.to_canonical_etype(etype)[0]): if not self.has_nodes(v, self.to_canonical_etype(etype)[0]):
raise DGLError('Non-existing node ID {}'.format(v)) raise DGLError("Non-existing node ID {}".format(v))
return self._graph.successors(self.get_etype_id(etype), v) return self._graph.successors(self.get_etype_id(etype), v)
def edge_ids(self, u, v, return_uv=False, etype=None): def edge_ids(self, u, v, return_uv=False, etype=None):
...@@ -3018,14 +3244,20 @@ class DGLGraph(object): ...@@ -3018,14 +3244,20 @@ class DGLGraph(object):
... etype=('user', 'follows', 'game')) ... etype=('user', 'follows', 'game'))
tensor([1, 2]) tensor([1, 2])
""" """
is_int = isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral) is_int = isinstance(u, numbers.Integral) and isinstance(
v, numbers.Integral
)
srctype, _, dsttype = self.to_canonical_etype(etype) srctype, _, dsttype = self.to_canonical_etype(etype)
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u): if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(
raise DGLError('u contains invalid node IDs') u
v = utils.prepare_tensor(self, v, 'v') ):
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v): raise DGLError("u contains invalid node IDs")
raise DGLError('v contains invalid node IDs') v = utils.prepare_tensor(self, v, "v")
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(
v
):
raise DGLError("v contains invalid node IDs")
if return_uv: if return_uv:
return self._graph.edge_ids_all(self.get_etype_id(etype), u, v) return self._graph.edge_ids_all(self.get_etype_id(etype), u, v)
...@@ -3035,9 +3267,13 @@ class DGLGraph(object): ...@@ -3035,9 +3267,13 @@ class DGLGraph(object):
if F.as_scalar(F.sum(is_neg_one, 0)): if F.as_scalar(F.sum(is_neg_one, 0)):
# Raise error since some (u, v) pair is not a valid edge. # Raise error since some (u, v) pair is not a valid edge.
idx = F.nonzero_1d(is_neg_one) idx = F.nonzero_1d(is_neg_one)
raise DGLError("Error: (%d, %d) does not form a valid edge." % ( raise DGLError(
"Error: (%d, %d) does not form a valid edge."
% (
F.as_scalar(F.gather_row(u, idx)), F.as_scalar(F.gather_row(u, idx)),
F.as_scalar(F.gather_row(v, idx)))) F.as_scalar(F.gather_row(v, idx)),
)
)
return F.as_scalar(eid) if is_int else eid return F.as_scalar(eid) if is_int else eid
def find_edges(self, eid, etype=None): def find_edges(self, eid, etype=None):
...@@ -3096,14 +3332,14 @@ class DGLGraph(object): ...@@ -3096,14 +3332,14 @@ class DGLGraph(object):
>>> hg.find_edges(torch.tensor([1, 0]), 'plays') >>> hg.find_edges(torch.tensor([1, 0]), 'plays')
(tensor([4, 3]), tensor([6, 5])) (tensor([4, 3]), tensor([6, 5]))
""" """
eid = utils.prepare_tensor(self, eid, 'eid') eid = utils.prepare_tensor(self, eid, "eid")
if len(eid) > 0: if len(eid) > 0:
min_eid = F.as_scalar(F.min(eid, 0)) min_eid = F.as_scalar(F.min(eid, 0))
if min_eid < 0: if min_eid < 0:
raise DGLError('Invalid edge ID {:d}'.format(min_eid)) raise DGLError("Invalid edge ID {:d}".format(min_eid))
max_eid = F.as_scalar(F.max(eid, 0)) max_eid = F.as_scalar(F.max(eid, 0))
if max_eid >= self.num_edges(etype): if max_eid >= self.num_edges(etype):
raise DGLError('Invalid edge ID {:d}'.format(max_eid)) raise DGLError("Invalid edge ID {:d}".format(max_eid))
if len(eid) == 0: if len(eid) == 0:
empty = F.copy_to(F.tensor([], self.idtype), self.device) empty = F.copy_to(F.tensor([], self.idtype), self.device)
...@@ -3111,7 +3347,7 @@ class DGLGraph(object): ...@@ -3111,7 +3347,7 @@ class DGLGraph(object):
src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid) src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src, dst return src, dst
def in_edges(self, v, form='uv', etype=None): def in_edges(self, v, form="uv", etype=None):
"""Return the incoming edges of the given nodes. """Return the incoming edges of the given nodes.
Parameters Parameters
...@@ -3184,18 +3420,20 @@ class DGLGraph(object): ...@@ -3184,18 +3420,20 @@ class DGLGraph(object):
edges edges
out_edges out_edges
""" """
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, "v")
src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v) src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
if form == 'all': if form == "all":
return src, dst, eid return src, dst, eid
elif form == 'uv': elif form == "uv":
return src, dst return src, dst
elif form == 'eid': elif form == "eid":
return eid return eid
else: else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)) raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
def out_edges(self, u, form='uv', etype=None): def out_edges(self, u, form="uv", etype=None):
"""Return the outgoing edges of the given nodes. """Return the outgoing edges of the given nodes.
Parameters Parameters
...@@ -3268,21 +3506,25 @@ class DGLGraph(object): ...@@ -3268,21 +3506,25 @@ class DGLGraph(object):
edges edges
in_edges in_edges
""" """
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
srctype, _, _ = self.to_canonical_etype(etype) srctype, _, _ = self.to_canonical_etype(etype)
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u): if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(
raise DGLError('u contains invalid node IDs') u
):
raise DGLError("u contains invalid node IDs")
src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u) src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
if form == 'all': if form == "all":
return src, dst, eid return src, dst, eid
elif form == 'uv': elif form == "uv":
return src, dst return src, dst
elif form == 'eid': elif form == "eid":
return eid return eid
else: else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)) raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
def all_edges(self, form='uv', order='eid', etype=None): def all_edges(self, form="uv", order="eid", etype=None):
"""Return all edges with the specified edge type. """Return all edges with the specified edge type.
Parameters Parameters
...@@ -3353,14 +3595,16 @@ class DGLGraph(object): ...@@ -3353,14 +3595,16 @@ class DGLGraph(object):
out_edges out_edges
""" """
src, dst, eid = self._graph.edges(self.get_etype_id(etype), order) src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
if form == 'all': if form == "all":
return src, dst, eid return src, dst, eid
elif form == 'uv': elif form == "uv":
return src, dst return src, dst
elif form == 'eid': elif form == "eid":
return eid return eid
else: else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)) raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
def in_degrees(self, v=ALL, etype=None): def in_degrees(self, v=ALL, etype=None):
"""Return the in-degree(s) of the given nodes. """Return the in-degree(s) of the given nodes.
...@@ -3431,7 +3675,7 @@ class DGLGraph(object): ...@@ -3431,7 +3675,7 @@ class DGLGraph(object):
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
if is_all(v): if is_all(v):
v = self.dstnodes(dsttype) v = self.dstnodes(dsttype)
v_tensor = utils.prepare_tensor(self, v, 'v') v_tensor = utils.prepare_tensor(self, v, "v")
deg = self._graph.in_degrees(etid, v_tensor) deg = self._graph.in_degrees(etid, v_tensor)
if isinstance(v, numbers.Integral): if isinstance(v, numbers.Integral):
return F.as_scalar(deg) return F.as_scalar(deg)
...@@ -3507,16 +3751,20 @@ class DGLGraph(object): ...@@ -3507,16 +3751,20 @@ class DGLGraph(object):
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
if is_all(u): if is_all(u):
u = self.srcnodes(srctype) u = self.srcnodes(srctype)
u_tensor = utils.prepare_tensor(self, u, 'u') u_tensor = utils.prepare_tensor(self, u, "u")
if F.as_scalar(F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)) != len(u_tensor): if F.as_scalar(
raise DGLError('u contains invalid node IDs') F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, 'u')) ) != len(u_tensor):
raise DGLError("u contains invalid node IDs")
deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, "u"))
if isinstance(u, numbers.Integral): if isinstance(u, numbers.Integral):
return F.as_scalar(deg) return F.as_scalar(deg)
else: else:
return deg return deg
def adjacency_matrix(self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None): def adjacency_matrix(
self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None
):
"""Alias of :meth:`adj`""" """Alias of :meth:`adj`"""
return self.adj(transpose, ctx, scipy_fmt, etype) return self.adj(transpose, ctx, scipy_fmt, etype)
...@@ -3586,7 +3834,9 @@ class DGLGraph(object): ...@@ -3586,7 +3834,9 @@ class DGLGraph(object):
if scipy_fmt is None: if scipy_fmt is None:
return self._graph.adjacency_matrix(etid, transpose, ctx)[0] return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
else: else:
return self._graph.adjacency_matrix_scipy(etid, transpose, scipy_fmt, False) return self._graph.adjacency_matrix_scipy(
etid, transpose, scipy_fmt, False
)
def adj_sparse(self, fmt, etype=None): def adj_sparse(self, fmt, etype=None):
"""Return the adjacency matrix of edges of the given edge type as tensors of """Return the adjacency matrix of edges of the given edge type as tensors of
...@@ -3629,9 +3879,9 @@ class DGLGraph(object): ...@@ -3629,9 +3879,9 @@ class DGLGraph(object):
(tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2])) (tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))
""" """
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
if fmt == 'csc': if fmt == "csc":
# The first two elements are number of rows and columns # The first two elements are number of rows and columns
return self._graph.adjacency_matrix_tensors(etid, True, 'csr')[2:] return self._graph.adjacency_matrix_tensors(etid, True, "csr")[2:]
else: else:
return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:] return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:]
...@@ -4024,26 +4274,36 @@ class DGLGraph(object): ...@@ -4024,26 +4274,36 @@ class DGLGraph(object):
if is_all(u): if is_all(u):
num_nodes = self._graph.number_of_nodes(ntid) num_nodes = self._graph.number_of_nodes(ntid)
else: else:
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
num_nodes = len(u) num_nodes = len(u)
for key, val in data.items(): for key, val in data.items():
nfeats = F.shape(val)[0] nfeats = F.shape(val)[0]
if nfeats != num_nodes: if nfeats != num_nodes:
raise DGLError('Expect number of features to match number of nodes (len(u)).' raise DGLError(
' Got %d and %d instead.' % (nfeats, num_nodes)) "Expect number of features to match number of nodes (len(u))."
" Got %d and %d instead." % (nfeats, num_nodes)
)
if F.context(val) != self.device: if F.context(val) != self.device:
raise DGLError('Cannot assign node feature "{}" on device {} to a graph on' raise DGLError(
' device {}. Call DGLGraph.to() to copy the graph to the' 'Cannot assign node feature "{}" on device {} to a graph on'
' same device.'.format(key, F.context(val), self.device)) " device {}. Call DGLGraph.to() to copy the graph to the"
" same device.".format(key, F.context(val), self.device)
)
# To prevent users from doing things like: # To prevent users from doing things like:
# #
# g.pin_memory_() # g.pin_memory_()
# g.ndata['x'] = torch.randn(...) # g.ndata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda()) # sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.ndata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing # sg.ndata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if self.is_pinned() and F.context(val) == 'cpu' and not F.is_pinned(val): if (
raise DGLError('Pinned graph requires the node data to be pinned as well. ' self.is_pinned()
'Please pin the node data before assignment.') and F.context(val) == "cpu"
and not F.is_pinned(val)
):
raise DGLError(
"Pinned graph requires the node data to be pinned as well. "
"Please pin the node data before assignment."
)
if is_all(u): if is_all(u):
self._node_frames[ntid].update(data) self._node_frames[ntid].update(data)
...@@ -4070,7 +4330,7 @@ class DGLGraph(object): ...@@ -4070,7 +4330,7 @@ class DGLGraph(object):
if is_all(u): if is_all(u):
return self._node_frames[ntid] return self._node_frames[ntid]
else: else:
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
return self._node_frames[ntid].subframe(u) return self._node_frames[ntid].subframe(u)
def _pop_n_repr(self, ntid, key): def _pop_n_repr(self, ntid, key):
...@@ -4116,12 +4376,14 @@ class DGLGraph(object): ...@@ -4116,12 +4376,14 @@ class DGLGraph(object):
""" """
# parse argument # parse argument
if not is_all(edges): if not is_all(edges):
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
# sanity check # sanity check
if not utils.is_dict_like(data): if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.' raise DGLError(
' Got "%s" instead.' % type(data)) "Expect dictionary type for feature data."
' Got "%s" instead.' % type(data)
)
if is_all(edges): if is_all(edges):
num_edges = self._graph.number_of_edges(etid) num_edges = self._graph.number_of_edges(etid)
...@@ -4130,21 +4392,31 @@ class DGLGraph(object): ...@@ -4130,21 +4392,31 @@ class DGLGraph(object):
for key, val in data.items(): for key, val in data.items():
nfeats = F.shape(val)[0] nfeats = F.shape(val)[0]
if nfeats != num_edges: if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.' raise DGLError(
' Got %d and %d instead.' % (nfeats, num_edges)) "Expect number of features to match number of edges."
" Got %d and %d instead." % (nfeats, num_edges)
)
if F.context(val) != self.device: if F.context(val) != self.device:
raise DGLError('Cannot assign edge feature "{}" on device {} to a graph on' raise DGLError(
' device {}. Call DGLGraph.to() to copy the graph to the' 'Cannot assign edge feature "{}" on device {} to a graph on'
' same device.'.format(key, F.context(val), self.device)) " device {}. Call DGLGraph.to() to copy the graph to the"
" same device.".format(key, F.context(val), self.device)
)
# To prevent users from doing things like: # To prevent users from doing things like:
# #
# g.pin_memory_() # g.pin_memory_()
# g.edata['x'] = torch.randn(...) # g.edata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda()) # sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing # sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if self.is_pinned() and F.context(val) == 'cpu' and not F.is_pinned(val): if (
raise DGLError('Pinned graph requires the edge data to be pinned as well. ' self.is_pinned()
'Please pin the edge data before assignment.') and F.context(val) == "cpu"
and not F.is_pinned(val)
):
raise DGLError(
"Pinned graph requires the edge data to be pinned as well. "
"Please pin the edge data before assignment."
)
# set # set
if is_all(edges): if is_all(edges):
...@@ -4172,7 +4444,7 @@ class DGLGraph(object): ...@@ -4172,7 +4444,7 @@ class DGLGraph(object):
if is_all(edges): if is_all(edges):
return self._edge_frames[etid] return self._edge_frames[etid]
else: else:
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
return self._edge_frames[etid].subframe(eid) return self._edge_frames[etid].subframe(eid)
def _pop_e_repr(self, etid, key): def _pop_e_repr(self, etid, key):
...@@ -4256,7 +4528,7 @@ class DGLGraph(object): ...@@ -4256,7 +4528,7 @@ class DGLGraph(object):
if is_all(v): if is_all(v):
v_id = self.nodes(ntype) v_id = self.nodes(ntype)
else: else:
v_id = utils.prepare_tensor(self, v, 'v') v_id = utils.prepare_tensor(self, v, "v")
ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id) ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
self._set_n_repr(ntid, v, ndata) self._set_n_repr(ntid, v, ndata)
...@@ -4348,14 +4620,16 @@ class DGLGraph(object): ...@@ -4348,14 +4620,16 @@ class DGLGraph(object):
g = self if etype is None else self[etype] g = self if etype is None else self[etype]
else: # heterogeneous graph with number of relation types > 1 else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(func): if not core.is_builtin(func):
raise DGLError("User defined functions are not yet " raise DGLError(
"User defined functions are not yet "
"supported in apply_edges for heterogeneous graphs. " "supported in apply_edges for heterogeneous graphs. "
"Please use (apply_edges(func), etype = rel) instead.") "Please use (apply_edges(func), etype = rel) instead."
)
g = self g = self
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
else: else:
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
if core.is_builtin(func): if core.is_builtin(func):
if not is_all(eid): if not is_all(eid):
g = g.edge_subgraph(eid, relabel_nodes=False) g = g.edge_subgraph(eid, relabel_nodes=False)
...@@ -4375,12 +4649,9 @@ class DGLGraph(object): ...@@ -4375,12 +4649,9 @@ class DGLGraph(object):
edata_tensor[key] = out_tensor_tuples[etid] edata_tensor[key] = out_tensor_tuples[etid]
self._set_e_repr(etid, eid, edata_tensor) self._set_e_repr(etid, eid, edata_tensor)
def send_and_recv(self, def send_and_recv(
edges, self, edges, message_func, reduce_func, apply_node_func=None, etype=None
message_func, ):
reduce_func,
apply_node_func=None,
etype=None):
"""Send messages along the specified edges and reduce them on """Send messages along the specified edges and reduce them on
the destination nodes to update their features. the destination nodes to update their features.
...@@ -4493,7 +4764,7 @@ class DGLGraph(object): ...@@ -4493,7 +4764,7 @@ class DGLGraph(object):
_, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
etype = self.canonical_etypes[etid] etype = self.canonical_etypes[etid]
# edge IDs # edge IDs
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges') eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
if len(eid) == 0: if len(eid) == 0:
# no computation # no computation
return return
...@@ -4502,15 +4773,13 @@ class DGLGraph(object): ...@@ -4502,15 +4773,13 @@ class DGLGraph(object):
g = self if etype is None else self[etype] g = self if etype is None else self[etype]
compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid) compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid)
ndata = core.message_passing( ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func) compute_graph, message_func, reduce_func, apply_node_func
)
self._set_n_repr(dtid, dstnodes, ndata) self._set_n_repr(dtid, dstnodes, ndata)
def pull(self, def pull(
v, self, v, message_func, reduce_func, apply_node_func=None, etype=None
message_func, ):
reduce_func,
apply_node_func=None,
etype=None):
"""Pull messages from the specified node(s)' predecessors along the """Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features. specified edge type, aggregate them to update the node features.
...@@ -4588,7 +4857,7 @@ class DGLGraph(object): ...@@ -4588,7 +4857,7 @@ class DGLGraph(object):
[1.], [1.],
[1.]]) [1.]])
""" """
v = utils.prepare_tensor(self, v, 'v') v = utils.prepare_tensor(self, v, "v")
if len(v) == 0: if len(v) == 0:
# no computation # no computation
return return
...@@ -4597,18 +4866,18 @@ class DGLGraph(object): ...@@ -4597,18 +4866,18 @@ class DGLGraph(object):
etype = self.canonical_etypes[etid] etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype] g = self if etype is None else self[etype]
# call message passing on subgraph # call message passing on subgraph
src, dst, eid = g.in_edges(v, form='all') src, dst, eid = g.in_edges(v, form="all")
compute_graph, _, dstnodes, _ = _create_compute_graph(g, src, dst, eid, v) compute_graph, _, dstnodes, _ = _create_compute_graph(
g, src, dst, eid, v
)
ndata = core.message_passing( ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func) compute_graph, message_func, reduce_func, apply_node_func
)
self._set_n_repr(dtid, dstnodes, ndata) self._set_n_repr(dtid, dstnodes, ndata)
def push(self, def push(
u, self, u, message_func, reduce_func, apply_node_func=None, etype=None
message_func, ):
reduce_func,
apply_node_func=None,
etype=None):
"""Send message from the specified node(s) to their successors """Send message from the specified node(s) to their successors
along the specified edge type and update their node features. along the specified edge type and update their node features.
...@@ -4679,14 +4948,14 @@ class DGLGraph(object): ...@@ -4679,14 +4948,14 @@ class DGLGraph(object):
[0.], [0.],
[0.]]) [0.]])
""" """
edges = self.out_edges(u, form='eid', etype=etype) edges = self.out_edges(u, form="eid", etype=etype)
self.send_and_recv(edges, message_func, reduce_func, apply_node_func, etype=etype) self.send_and_recv(
edges, message_func, reduce_func, apply_node_func, etype=etype
def update_all(self, )
message_func,
reduce_func, def update_all(
apply_node_func=None, self, message_func, reduce_func, apply_node_func=None, etype=None
etype=None): ):
"""Send messages along all the edges of the specified type """Send messages along all the edges of the specified type
and update all the nodes of the corresponding destination type. and update all the nodes of the corresponding destination type.
...@@ -4778,23 +5047,37 @@ class DGLGraph(object): ...@@ -4778,23 +5047,37 @@ class DGLGraph(object):
etype = self.canonical_etypes[etid] etype = self.canonical_etypes[etid]
_, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
g = self if etype is None else self[etype] g = self if etype is None else self[etype]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func) ndata = core.message_passing(
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata: g, message_func, reduce_func, apply_node_func
)
if (
core.is_builtin(reduce_func)
and reduce_func.name in ["min", "max"]
and ndata
):
# Replace infinity with zero for isolated nodes # Replace infinity with zero for isolated nodes
key = list(ndata.keys())[0] key = list(ndata.keys())[0]
ndata[key] = F.replace_inf_with_zero(ndata[key]) ndata[key] = F.replace_inf_with_zero(ndata[key])
self._set_n_repr(dtid, ALL, ndata) self._set_n_repr(dtid, ALL, ndata)
else: # heterogeneous graph with number of relation types > 1 else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(message_func) or not core.is_builtin(reduce_func): if not core.is_builtin(message_func) or not core.is_builtin(
raise DGLError("User defined functions are not yet " reduce_func
):
raise DGLError(
"User defined functions are not yet "
"supported in update_all for heterogeneous graphs. " "supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.") "Please use multi_update_all instead."
if reduce_func.name in ['mean']: )
raise NotImplementedError("Cannot set both intra-type and inter-type reduce " if reduce_func.name in ["mean"]:
raise NotImplementedError(
"Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use " "operators as 'mean' using update_all. Please use "
"multi_update_all instead.") "multi_update_all instead."
)
g = self g = self
all_out = core.message_passing(g, message_func, reduce_func, apply_node_func) all_out = core.message_passing(
g, message_func, reduce_func, apply_node_func
)
key = list(all_out.keys())[0] key = list(all_out.keys())[0]
out_tensor_tuples = all_out[key] out_tensor_tuples = all_out[key]
...@@ -4802,7 +5085,10 @@ class DGLGraph(object): ...@@ -4802,7 +5085,10 @@ class DGLGraph(object):
for _, _, dsttype in g.canonical_etypes: for _, _, dsttype in g.canonical_etypes:
dtid = g.get_ntype_id(dsttype) dtid = g.get_ntype_id(dsttype)
dst_tensor[key] = out_tensor_tuples[dtid] dst_tensor[key] = out_tensor_tuples[dtid]
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max']: if core.is_builtin(reduce_func) and reduce_func.name in [
"min",
"max",
]:
dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key]) dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key])
self._node_frames[dtid].update(dst_tensor) self._node_frames[dtid].update(dst_tensor)
...@@ -4902,36 +5188,44 @@ class DGLGraph(object): ...@@ -4902,36 +5188,44 @@ class DGLGraph(object):
_, dtid = self._graph.metagraph.find_edge(etid) _, dtid = self._graph.metagraph.find_edge(etid)
args = pad_tuple(args, 3) args = pad_tuple(args, 3)
if args is None: if args is None:
raise DGLError('Invalid arguments for edge type "{}". Should be ' raise DGLError(
'(msg_func, reduce_func, [apply_node_func])'.format(etype)) 'Invalid arguments for edge type "{}". Should be '
"(msg_func, reduce_func, [apply_node_func])".format(etype)
)
mfunc, rfunc, afunc = args mfunc, rfunc, afunc = args
g = self if etype is None else self[etype] g = self if etype is None else self[etype]
all_out[dtid].append(core.message_passing(g, mfunc, rfunc, afunc)) all_out[dtid].append(core.message_passing(g, mfunc, rfunc, afunc))
merge_order[dtid].append(etid) # use edge type id as merge order hint merge_order[dtid].append(
etid
) # use edge type id as merge order hint
for dtid, frames in all_out.items(): for dtid, frames in all_out.items():
# merge by cross_reducer # merge by cross_reducer
out = reduce_dict_data(frames, cross_reducer, merge_order[dtid]) out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])
# Replace infinity with zero for isolated nodes when reducer is min/max # Replace infinity with zero for isolated nodes when reducer is min/max
if core.is_builtin(rfunc) and rfunc.name in ['min', 'max']: if core.is_builtin(rfunc) and rfunc.name in ["min", "max"]:
key = list(out.keys())[0] key = list(out.keys())[0]
out[key] = F.replace_inf_with_zero(out[key]) if out[key] is not None else None out[key] = (
F.replace_inf_with_zero(out[key])
if out[key] is not None
else None
)
self._node_frames[dtid].update(out) self._node_frames[dtid].update(out)
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid]) self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])
################################################################# #################################################################
# Message propagation # Message propagation
################################################################# #################################################################
def prop_nodes(self, def prop_nodes(
self,
nodes_generator, nodes_generator,
message_func, message_func,
reduce_func, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None): etype=None,
):
"""Propagate messages using graph traversal by sequentially triggering """Propagate messages using graph traversal by sequentially triggering
:func:`pull()` on nodes. :func:`pull()` on nodes.
...@@ -4987,14 +5281,22 @@ class DGLGraph(object): ...@@ -4987,14 +5281,22 @@ class DGLGraph(object):
prop_edges prop_edges
""" """
for node_frontier in nodes_generator: for node_frontier in nodes_generator:
self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype) self.pull(
node_frontier,
message_func,
reduce_func,
apply_node_func,
etype=etype,
)
def prop_edges(self, def prop_edges(
self,
edges_generator, edges_generator,
message_func, message_func,
reduce_func, reduce_func,
apply_node_func=None, apply_node_func=None,
etype=None): etype=None,
):
"""Propagate messages using graph traversal by sequentially triggering """Propagate messages using graph traversal by sequentially triggering
:func:`send_and_recv()` on edges. :func:`send_and_recv()` on edges.
...@@ -5051,8 +5353,13 @@ class DGLGraph(object): ...@@ -5051,8 +5353,13 @@ class DGLGraph(object):
prop_nodes prop_nodes
""" """
for edge_frontier in edges_generator: for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier, message_func, reduce_func, self.send_and_recv(
apply_node_func, etype=etype) edge_frontier,
message_func,
reduce_func,
apply_node_func,
etype=etype,
)
################################################################# #################################################################
# Misc # Misc
...@@ -5127,14 +5434,16 @@ class DGLGraph(object): ...@@ -5127,14 +5434,16 @@ class DGLGraph(object):
""" """
if is_all(nodes): if is_all(nodes):
nodes = self.nodes(ntype) nodes = self.nodes(ntype)
v = utils.prepare_tensor(self, nodes, 'nodes') v = utils.prepare_tensor(self, nodes, "nodes")
if F.as_scalar(F.sum(self.has_nodes(v, ntype=ntype), dim=0)) != len(v): if F.as_scalar(F.sum(self.has_nodes(v, ntype=ntype), dim=0)) != len(v):
raise DGLError('v contains invalid node IDs') raise DGLError("v contains invalid node IDs")
with self.local_scope(): with self.local_scope():
self.apply_nodes(lambda nbatch: {'_mask' : predicate(nbatch)}, nodes, ntype) self.apply_nodes(
lambda nbatch: {"_mask": predicate(nbatch)}, nodes, ntype
)
ntype = self.ntypes[0] if ntype is None else ntype ntype = self.ntypes[0] if ntype is None else ntype
mask = self.nodes[ntype].data['_mask'] mask = self.nodes[ntype].data["_mask"]
if is_all(nodes): if is_all(nodes):
return F.nonzero_1d(mask) return F.nonzero_1d(mask)
else: else:
...@@ -5221,34 +5530,40 @@ class DGLGraph(object): ...@@ -5221,34 +5530,40 @@ class DGLGraph(object):
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
srctype, _, dsttype = self.to_canonical_etype(etype) srctype, _, dsttype = self.to_canonical_etype(etype)
u = utils.prepare_tensor(self, u, 'u') u = utils.prepare_tensor(self, u, "u")
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u): if F.as_scalar(
raise DGLError('edges[0] contains invalid node IDs') F.sum(self.has_nodes(u, ntype=srctype), dim=0)
v = utils.prepare_tensor(self, v, 'v') ) != len(u):
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v): raise DGLError("edges[0] contains invalid node IDs")
raise DGLError('edges[1] contains invalid node IDs') v = utils.prepare_tensor(self, v, "v")
if F.as_scalar(
F.sum(self.has_nodes(v, ntype=dsttype), dim=0)
) != len(v):
raise DGLError("edges[1] contains invalid node IDs")
elif isinstance(edges, Iterable) or F.is_tensor(edges): elif isinstance(edges, Iterable) or F.is_tensor(edges):
edges = utils.prepare_tensor(self, edges, 'edges') edges = utils.prepare_tensor(self, edges, "edges")
min_eid = F.as_scalar(F.min(edges, 0)) min_eid = F.as_scalar(F.min(edges, 0))
if len(edges) > 0 > min_eid: if len(edges) > 0 > min_eid:
raise DGLError('Invalid edge ID {:d}'.format(min_eid)) raise DGLError("Invalid edge ID {:d}".format(min_eid))
max_eid = F.as_scalar(F.max(edges, 0)) max_eid = F.as_scalar(F.max(edges, 0))
if len(edges) > 0 and max_eid >= self.num_edges(etype): if len(edges) > 0 and max_eid >= self.num_edges(etype):
raise DGLError('Invalid edge ID {:d}'.format(max_eid)) raise DGLError("Invalid edge ID {:d}".format(max_eid))
else: else:
raise ValueError('Unsupported type of edges:', type(edges)) raise ValueError("Unsupported type of edges:", type(edges))
with self.local_scope(): with self.local_scope():
self.apply_edges(lambda ebatch: {'_mask' : predicate(ebatch)}, edges, etype) self.apply_edges(
lambda ebatch: {"_mask": predicate(ebatch)}, edges, etype
)
etype = self.canonical_etypes[0] if etype is None else etype etype = self.canonical_etypes[0] if etype is None else etype
mask = self.edges[etype].data['_mask'] mask = self.edges[etype].data["_mask"]
if is_all(edges): if is_all(edges):
return F.nonzero_1d(mask) return F.nonzero_1d(mask)
else: else:
if isinstance(edges, tuple): if isinstance(edges, tuple):
e = self.edge_ids(edges[0], edges[1], etype=etype) e = self.edge_ids(edges[0], edges[1], etype=etype)
else: else:
e = utils.prepare_tensor(self, edges, 'edges') e = utils.prepare_tensor(self, edges, "edges")
return F.boolean_mask(e, F.gather_row(mask, e)) return F.boolean_mask(e, F.gather_row(mask, e))
@property @property
...@@ -5347,12 +5662,16 @@ class DGLGraph(object): ...@@ -5347,12 +5662,16 @@ class DGLGraph(object):
# 2. Copy misc info # 2. Copy misc info
if self._batch_num_nodes is not None: if self._batch_num_nodes is not None:
new_bnn = {k : F.copy_to(num, device, **kwargs) new_bnn = {
for k, num in self._batch_num_nodes.items()} k: F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_nodes.items()
}
ret._batch_num_nodes = new_bnn ret._batch_num_nodes = new_bnn
if self._batch_num_edges is not None: if self._batch_num_edges is not None:
new_bne = {k : F.copy_to(num, device, **kwargs) new_bne = {
for k, num in self._batch_num_edges.items()} k: F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_edges.items()
}
ret._batch_num_edges = new_bne ret._batch_num_edges = new_bne
return ret return ret
...@@ -5432,8 +5751,10 @@ class DGLGraph(object): ...@@ -5432,8 +5751,10 @@ class DGLGraph(object):
tensor([0, 1, 1]) tensor([0, 1, 1])
""" """
if not self._graph.is_pinned(): if not self._graph.is_pinned():
if F.device_type(self.device) != 'cpu': if F.device_type(self.device) != "cpu":
raise DGLError("The graph structure must be on CPU to be pinned.") raise DGLError(
"The graph structure must be on CPU to be pinned."
)
self._graph.pin_memory_() self._graph.pin_memory_()
for frame in itertools.chain(self._node_frames, self._edge_frames): for frame in itertools.chain(self._node_frames, self._edge_frames):
for col in frame._columns.values(): for col in frame._columns.values():
...@@ -5484,9 +5805,9 @@ class DGLGraph(object): ...@@ -5484,9 +5805,9 @@ class DGLGraph(object):
DGLGraph DGLGraph
self. self.
""" """
if F.get_preferred_backend() != 'pytorch': if F.get_preferred_backend() != "pytorch":
raise DGLError("record_stream only support the PyTorch backend.") raise DGLError("record_stream only support the PyTorch backend.")
if F.device_type(self.device) != 'cuda': if F.device_type(self.device) != "cuda":
raise DGLError("The graph must be on GPU to be recorded.") raise DGLError("The graph must be on GPU to be recorded.")
self._graph.record_stream(stream) self._graph.record_stream(stream)
for frame in itertools.chain(self._node_frames, self._edge_frames): for frame in itertools.chain(self._node_frames, self._edge_frames):
...@@ -5510,15 +5831,24 @@ class DGLGraph(object): ...@@ -5510,15 +5831,24 @@ class DGLGraph(object):
# Clone the graph structure # Clone the graph structure
meta_edges = [] meta_edges = []
for s_ntype, _, d_ntype in self.canonical_etypes: for s_ntype, _, d_ntype in self.canonical_etypes:
meta_edges.append((self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))) meta_edges.append(
(self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))
)
metagraph = graph_index.from_edge_list(meta_edges, True) metagraph = graph_index.from_edge_list(meta_edges, True)
# rebuild graph idx # rebuild graph idx
num_nodes_per_type = [self.number_of_nodes(c_ntype) for c_ntype in self.ntypes] num_nodes_per_type = [
relation_graphs = [self._graph.get_relation_graph(self.get_etype_id(c_etype)) self.number_of_nodes(c_ntype) for c_ntype in self.ntypes
for c_etype in self.canonical_etypes] ]
relation_graphs = [
self._graph.get_relation_graph(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes
]
ret._graph = heterograph_index.create_heterograph_from_relations( ret._graph = heterograph_index.create_heterograph_from_relations(
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64")) metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
# Clone the frames # Clone the frames
ret._node_frames = [fr.clone() for fr in self._node_frames] ret._node_frames = [fr.clone() for fr in self._node_frames]
...@@ -5819,7 +6149,7 @@ class DGLGraph(object): ...@@ -5819,7 +6149,7 @@ class DGLGraph(object):
return ret return ret
# TODO: Formats should not be specified, just saving all the materialized formats # TODO: Formats should not be specified, just saving all the materialized formats
def shared_memory(self, name, formats=('coo', 'csr', 'csc')): def shared_memory(self, name, formats=("coo", "csr", "csc")):
"""Return a copy of this graph in shared memory, without node data or edge data. """Return a copy of this graph in shared memory, without node data or edge data.
It moves the graph index to shared memory and returns a DGLGraph object which It moves the graph index to shared memory and returns a DGLGraph object which
...@@ -5843,11 +6173,16 @@ class DGLGraph(object): ...@@ -5843,11 +6173,16 @@ class DGLGraph(object):
if isinstance(formats, str): if isinstance(formats, str):
formats = [formats] formats = [formats]
for fmt in formats: for fmt in formats:
assert fmt in ("coo", "csr", "csc"), '{} is not coo, csr or csc'.format(fmt) assert fmt in (
gidx = self._graph.shared_memory(name, self.ntypes, self.etypes, formats) "coo",
"csr",
"csc",
), "{} is not coo, csr or csc".format(fmt)
gidx = self._graph.shared_memory(
name, self.ntypes, self.etypes, formats
)
return DGLGraph(gidx, self.ntypes, self.etypes) return DGLGraph(gidx, self.ntypes, self.etypes)
def long(self): def long(self):
"""Cast the graph to one with idtype int64 """Cast the graph to one with idtype int64
...@@ -5948,10 +6283,12 @@ class DGLGraph(object): ...@@ -5948,10 +6283,12 @@ class DGLGraph(object):
""" """
return self.astype(F.int32) return self.astype(F.int32)
############################################################ ############################################################
# Internal APIs # Internal APIs
############################################################ ############################################################
def make_canonical_etypes(etypes, ntypes, metagraph):
    """Internal function to convert etype name to (srctype, etype, dsttype).

    Parameters
    ----------
    etypes : list[str]
        Edge type names, one per metagraph edge, in edge-ID order.
    ntypes : list[str]
        Node type names, one per metagraph node.
    metagraph : GraphIndex
        Metagraph whose edges connect source node types to destination
        node types.

    Returns
    -------
    list[tuple[str, str, str]]
        The ``(srctype, etype, dsttype)`` triple for every edge type.

    Raises
    ------
    DGLError
        If the length of ``etypes``/``ntypes`` does not match the number
        of edges/nodes in the metagraph.
    """
    # sanity check
    if len(etypes) != metagraph.number_of_edges():
        raise DGLError(
            "Length of edge type list must match the number of "
            "edges in the metagraph. {} vs {}".format(
                len(etypes), metagraph.number_of_edges()
            )
        )
    if len(ntypes) != metagraph.number_of_nodes():
        raise DGLError(
            "Length of nodes type list must match the number of "
            "nodes in the metagraph. {} vs {}".format(
                len(ntypes), metagraph.number_of_nodes()
            )
        )
    # Fast path: a homogeneous graph has exactly one canonical triple.
    if len(etypes) == 1 and len(ntypes) == 1:
        return [(ntypes[0], etypes[0], ntypes[0])]
    src, dst, eid = metagraph.edges(order="eid")
    rst = [
        (ntypes[sid], etypes[eid], ntypes[did])
        for sid, did, eid in zip(src, dst, eid)
    ]
    return rst
def find_src_dst_ntypes(ntypes, metagraph): def find_src_dst_ntypes(ntypes, metagraph):
"""Internal function to split ntypes into SRC and DST categories. """Internal function to split ntypes into SRC and DST categories.
...@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph): ...@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph):
return None return None
else: else:
src, dst = ret src, dst = ret
srctypes = {ntypes[tid] : tid for tid in src} srctypes = {ntypes[tid]: tid for tid in src}
dsttypes = {ntypes[tid] : tid for tid in dst} dsttypes = {ntypes[tid]: tid for tid in dst}
return srctypes, dsttypes return srctypes, dsttypes
def pad_tuple(tup, length, pad_val=None): def pad_tuple(tup, length, pad_val=None):
"""Pad the given tuple to the given length. """Pad the given tuple to the given length.
...@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None): ...@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None):
Return None if pad fails. Return None if pad fails.
""" """
if not isinstance(tup, tuple): if not isinstance(tup, tuple):
tup = (tup, ) tup = (tup,)
if len(tup) > length: if len(tup) > length:
return None return None
elif len(tup) == length: elif len(tup) == length:
...@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None): ...@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None):
else: else:
return tup + (pad_val,) * (length - len(tup)) return tup + (pad_val,) * (length - len(tup))
def reduce_dict_data(frames, reducer, order=None): def reduce_dict_data(frames, reducer, order=None):
"""Merge tensor dictionaries into one. Resolve conflict fields using reducer. """Merge tensor dictionaries into one. Resolve conflict fields using reducer.
...@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None): ...@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None):
dict[str, Tensor] dict[str, Tensor]
Merged frame Merged frame
""" """
if len(frames) == 1 and reducer != 'stack': if len(frames) == 1 and reducer != "stack":
# Directly return the only one input. Stack reducer requires # Directly return the only one input. Stack reducer requires
# modifying tensor shape. # modifying tensor shape.
return frames[0] return frames[0]
if callable(reducer): if callable(reducer):
merger = reducer merger = reducer
elif reducer == 'stack': elif reducer == "stack":
# Stack order does not matter. However, it must be consistent! # Stack order does not matter. However, it must be consistent!
if order: if order:
assert len(order) == len(frames) assert len(order) == len(frames)
sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1]) sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
frames = list(zip(*sorted_with_key))[0] frames = list(zip(*sorted_with_key))[0]
def merger(flist): def merger(flist):
return F.stack(flist, 1) return F.stack(flist, 1)
else: else:
redfn = getattr(F, reducer, None) redfn = getattr(F, reducer, None)
if redfn is None: if redfn is None:
raise DGLError('Invalid cross type reducer. Must be one of ' raise DGLError(
'"sum", "max", "min", "mean" or "stack".') "Invalid cross type reducer. Must be one of "
'"sum", "max", "min", "mean" or "stack".'
)
def merger(flist): def merger(flist):
return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0] return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
keys = set() keys = set()
for frm in frames: for frm in frames:
keys.update(frm.keys()) keys.update(frm.keys())
...@@ -6087,6 +6442,7 @@ def reduce_dict_data(frames, reducer, order=None): ...@@ -6087,6 +6442,7 @@ def reduce_dict_data(frames, reducer, order=None):
ret[k] = merger(flist) ret[k] = merger(flist)
return ret return ret
def combine_frames(frames, ids, col_names=None): def combine_frames(frames, ids, col_names=None):
"""Merge the frames into one frame, taking the common columns. """Merge the frames into one frame, taking the common columns.
...@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None): ...@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None):
for key, scheme in list(schemes.items()): for key, scheme in list(schemes.items()):
if key in frame.schemes: if key in frame.schemes:
if frame.schemes[key] != scheme: if frame.schemes[key] != scheme:
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' % raise DGLError(
(key, frame.schemes[key], scheme)) "Cannot concatenate column %s with shape %s and shape %s"
% (key, frame.schemes[key], scheme)
)
else: else:
del schemes[key] del schemes[key]
...@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None): ...@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None):
cols = {key: F.cat(to_cat(key), dim=0) for key in schemes} cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
return Frame(cols) return Frame(cols)
def combine_names(names, ids=None):
    """Combine the selected names into one new name.

    Parameters
    ----------
    names : list[str]
        String names to combine.
    ids : iterable[int], optional
        Indices of the names to select. If None, select all.

    Returns
    -------
    str
        The selected names, sorted and joined with ``'+'``.
    """
    if ids is None:
        return "+".join(sorted(names))
    # Select first, then sort, so the result is order-independent.
    selected = sorted(names[i] for i in ids)
    return "+".join(selected)
class DGLBlock(DGLGraph):
    """Subclass that signifies the graph is a block created from
    :func:`dgl.to_block`.
    """

    # (BarclayII) I'm making a subclass because I don't want to make another version of
    # serialization that contains the is_block flag.
    is_block = True

    def __repr__(self):
        """Return a concise summary for homogeneous blocks, or a multi-line
        per-type breakdown (including the metagraph) for heterogeneous ones."""
        if (
            len(self.srctypes) == 1
            and len(self.dsttypes) == 1
            and len(self.etypes) == 1
        ):
            # Single src/dst/edge type: one-line summary with plain counts.
            ret = "Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})"
            return ret.format(
                srcnode=self.number_of_src_nodes(),
                dstnode=self.number_of_dst_nodes(),
                edge=self.number_of_edges(),
            )
        else:
            # NOTE(review): leading spaces in the continuation lines were
            # lost in this view; the 6-space alignment below is inferred —
            # confirm against upstream DGL.
            ret = (
                "Block(num_src_nodes={srcnode},\n"
                "      num_dst_nodes={dstnode},\n"
                "      num_edges={edge},\n"
                "      metagraph={meta})"
            )
            # Per-type counts, keyed by node type / canonical edge type.
            nsrcnode_dict = {
                ntype: self.number_of_src_nodes(ntype)
                for ntype in self.srctypes
            }
            ndstnode_dict = {
                ntype: self.number_of_dst_nodes(ntype)
                for ntype in self.dsttypes
            }
            nedge_dict = {
                etype: self.number_of_edges(etype)
                for etype in self.canonical_etypes
            }
            meta = str(self.metagraph().edges(keys=True))
            return ret.format(
                srcnode=nsrcnode_dict,
                dstnode=ndstnode_dict,
                edge=nedge_dict,
                meta=meta,
            )
def _create_compute_graph(graph, u, v, eid, recv_nodes=None): def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
...@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None): ...@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
srctype, etype, dsttype = graph.canonical_etypes[0] srctype, etype, dsttype = graph.canonical_etypes[0]
# create graph # create graph
hgidx = heterograph_index.create_unitgraph_from_coo( hgidx = heterograph_index.create_unitgraph_from_coo(
2, len(unique_src), len(unique_dst), new_u, new_v, ['coo', 'csr', 'csc']) 2, len(unique_src), len(unique_dst), new_u, new_v, ["coo", "csr", "csc"]
)
# create frame # create frame
srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(unique_src) srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(
unique_src
)
srcframe[NID] = unique_src srcframe[NID] = unique_src
dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(unique_dst) dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(
unique_dst
)
dstframe[NID] = unique_dst dstframe[NID] = unique_dst
eframe = graph._edge_frames[0].subframe(eid) eframe = graph._edge_frames[0].subframe(eid)
eframe[EID] = eid eframe[EID] = eid
return DGLGraph(hgidx, ([srctype], [dsttype]), [etype], return (
DGLGraph(
hgidx,
([srctype], [dsttype]),
[etype],
node_frames=[srcframe, dstframe], node_frames=[srcframe, dstframe],
edge_frames=[eframe]), unique_src, unique_dst, eid edge_frames=[eframe],
),
unique_src,
unique_dst,
eid,
)
_init_api("dgl.heterograph") _init_api("dgl.heterograph")
...@@ -7,12 +7,11 @@ import sys ...@@ -7,12 +7,11 @@ import sys
import numpy as np import numpy as np
import scipy import scipy
from . import backend as F from . import backend as F, utils
from . import utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle from ._ffi.streams import to_dgl_stream_handle
from .base import DGLError, dgl_warning from .base import dgl_warning, DGLError
from .graph_index import from_coo from .graph_index import from_coo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment