Unverified commit d1827488, authored by Hongzhi (Steve), Chen and committed by GitHub
parent 3e5137fe
"""Base classes and functionalities for dataloaders"""
import inspect
from collections.abc import Mapping

from .. import backend as F
from ..base import EID, NID
from ..convert import heterograph
from ..frame import LazyFeature
from ..transforms import compact_graphs
from ..utils import context_of, recursive_apply
def _set_lazy_features(x, xdata, feature_names):
if feature_names is None:
@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
for type_, names in feature_names.items():
x[type_].data.update({k: LazyFeature(k) for k in names})
def set_node_lazy_features(g, feature_names):
"""Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.nodes, g.ndata, feature_names)
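# A minimal usage sketch (the toy "user"/"follows" graph below is assumed for
# illustration): for a heterogeneous graph, feature_names maps each node type
# to the feature names to prefetch; a homogeneous graph takes a plain list.
#
# >>> import dgl, torch
# >>> g = dgl.heterograph(
# ...     {("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1]))})
# >>> g.nodes["user"].data["feat"] = torch.randn(2, 4)
# >>> set_node_lazy_features(g, {"user": ["feat"]})  # mark "feat" for prefetching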
def set_edge_lazy_features(g, feature_names):
"""Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.edges, g.edata, feature_names)
def set_src_lazy_features(g, feature_names):
"""Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)
def set_dst_lazy_features(g, feature_names):
"""Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
"""
return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)
class Sampler(object):
"""Base class for graph samplers.
@@ -171,6 +178,7 @@ class Sampler(object):
def sample(self, g, indices):
return g.subgraph(indices)
"""
def sample(self, g, indices):
"""Abstract sample method.
@@ -183,6 +191,7 @@ class Sampler(object):
"""
raise NotImplementedError
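# A minimal sketch of a concrete sampler, mirroring the docstring example above:
# subclasses only need to override sample(), which maps a minibatch of indices
# to whatever structure the training loop consumes.
#
# >>> class SubgraphSampler(Sampler):
# ...     def sample(self, g, indices):
# ...         return g.subgraph(indices)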
class BlockSampler(Sampler):
"""Base class for sampling mini-batches in the form of Message-passing
Flow Graphs (MFGs).
@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes.
"""
def __init__(
self,
prefetch_node_feats=None,
prefetch_labels=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__()
self.prefetch_node_feats = prefetch_node_feats or []
self.prefetch_labels = prefetch_labels or []
@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
set_edge_lazy_features(block, self.prefetch_edge_feats)
return input_nodes, output_nodes, blocks
def sample(
self, g, seed_nodes, exclude_eids=None
): # pylint: disable=arguments-differ
"""Sample a list of blocks from the given seed nodes."""
result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
return self.assign_lazy_features(result)
@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
for k, v in eids.items()
}
else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids
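# Worked example of the reverse-ID convention assumed here: reverse_eid_map[i]
# holds the ID of edge i's reverse, so excluding the seed edges also excludes
# their reverses (toy tensors below are for illustration only).
#
# >>> import torch
# >>> reverse_eid_map = torch.tensor([1, 0, 3, 2])  # edges stored as reverse pairs
# >>> eids = torch.tensor([0, 2])
# >>> # result: tensor([0, 2, 1, 3]), i.e. seeds followed by their reverses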
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items()
}
exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()}
)
return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
if exclude_mode is None:
return None
elif callable(exclude_mode):
return exclude_mode(eids)
elif F.is_tensor(exclude_mode) or (
isinstance(exclude_mode, Mapping)
and all(F.is_tensor(v) for v in exclude_mode.values())
):
return exclude_mode
elif exclude_mode == "self":
return eids
elif exclude_mode == "reverse_id":
return _find_exclude_eids_with_reverse_id(
g, eids, kwargs["reverse_eid_map"]
)
elif exclude_mode == "reverse_types":
return _find_exclude_eids_with_reverse_types(
g, eids, kwargs["reverse_etype_map"]
)
else:
raise ValueError("unsupported mode {}".format(exclude_mode))
def find_exclude_eids(
g,
seed_edges,
exclude,
reverse_eids=None,
reverse_etypes=None,
output_device=None,
):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
Parameters
@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
exclude,
seed_edges,
reverse_eid_map=reverse_eids,
reverse_etype_map=reverse_etypes,
)
if exclude_eids is not None and output_device is not None:
exclude_eids = recursive_apply(
exclude_eids, lambda x: F.copy_to(x, output_device)
)
return exclude_eids
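# Usage sketch: with exclude="reverse_id" the caller supplies reverse_eids,
# while exclude="reverse_types" expects a dict of reverse edge types; g,
# seed_edges and reverse_eid_map below are assumed to exist.
#
# >>> exclude_eids = find_exclude_eids(
# ...     g, seed_edges, "reverse_id", reverse_eids=reverse_eid_map)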
class EdgePredictionSampler(Sampler):
"""Sampler class that wraps an existing sampler for node classification into another
one for edge classification or link prediction.
@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
--------
as_edge_prediction_sampler
"""
def __init__(
self,
sampler,
exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
prefetch_labels=None,
):
super().__init__()
# Check if the sampler's sample method has an optional third argument.
argspec = inspect.getfullargspec(sampler.sample)
if len(argspec.args) < 4: # ['self', 'g', 'indices', 'exclude_eids']
raise TypeError(
"This sampler does not support edge or link prediction; please add an"
"optional third argument for edge IDs to exclude in its sample() method.")
"optional third argument for edge IDs to exclude in its sample() method."
)
self.reverse_eids = reverse_eids
self.reverse_etypes = reverse_etypes
self.exclude = exclude
@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
def _build_neg_graph(self, g, seed_edges):
neg_srcdst = self.negative_sampler(g, seed_edges)
if not isinstance(neg_srcdst, Mapping):
assert len(g.canonical_etypes) == 1, (
"graph has multiple or no edge types; "
"please return a dict in negative sampler."
)
neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = context_of(seed_edges) if seed_edges is not None else g.device
neg_edges = {
etype: neg_srcdst.get(
etype,
(
F.copy_to(F.tensor([], dtype), ctx=ctx),
F.copy_to(F.tensor([], dtype), ctx=ctx),
),
)
for etype in g.canonical_etypes
}
neg_pair_graph = heterograph(
neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
)
return neg_pair_graph
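# Sketch of the negative-sampler contract used above: the callable receives
# (g, seed_edges) and returns a (src, dst) pair, or a dict keyed by canonical
# edge type for graphs with multiple edge types. DGL ships such samplers, e.g.
#
# >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)  # 5 negatives per edge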
def assign_lazy_features(self, result):
@@ -390,7 +445,7 @@ class EdgePredictionSampler(Sampler):
# In-place updates
return result
def sample(self, g, seed_edges): # pylint: disable=arguments-differ
"""Samples a list of blocks, as well as a subgraph containing the sampled
edges from the original graph.
@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
negative pairs as edges.
"""
if isinstance(seed_edges, Mapping):
seed_edges = {
g.to_canonical_etype(k): v for k, v in seed_edges.items()
}
exclude = self.exclude
pair_graph = g.edge_subgraph(
seed_edges, relabel_nodes=False, output_device=self.output_device
)
eids = pair_graph.edata[EID]
if self.negative_sampler is not None:
@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
seed_nodes = pair_graph.ndata[NID]
exclude_eids = find_exclude_eids(
g,
seed_edges,
exclude,
self.reverse_eids,
self.reverse_etypes,
self.output_device,
)
input_nodes, _, blocks = self.sampler.sample(
g, seed_nodes, exclude_eids
)
if self.negative_sampler is None:
return self.assign_lazy_features((input_nodes, pair_graph, blocks))
else:
return self.assign_lazy_features(
(input_nodes, pair_graph, neg_graph, blocks)
)
def as_edge_prediction_sampler(
sampler,
exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
prefetch_labels=None,
):
"""Create an edge-wise sampler from a node-wise sampler.
For each batch of edges, the sampler applies the provided node-wise sampler to
@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
"""
return EdgePredictionSampler(
sampler,
exclude=exclude,
reverse_eids=reverse_eids,
reverse_etypes=reverse_etypes,
negative_sampler=negative_sampler,
prefetch_labels=prefetch_labels,
)
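# End-to-end usage sketch (PyTorch backend and existing g/train_eids assumed),
# combining the wrapper above with a node-wise sampler:
#
# >>> sampler = dgl.dataloading.NeighborSampler([10, 10])
# >>> sampler = dgl.dataloading.as_edge_prediction_sampler(
# ...     sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(5))
# >>> dataloader = dgl.dataloading.DataLoader(
# ...     g, train_eids, sampler, batch_size=1024, shuffle=True)
# >>> for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
# ...     pass  # score pair_graph/neg_pair_graph with representations from blocks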
"""Distributed dataloaders.
"""
import inspect
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Mapping

from .. import backend as F, transforms, utils
from ..base import EID, NID
from ..convert import heterograph
from ..distributed import DistDataLoader
@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
exclude_eids = {
k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
for k, v in eids.items()
}
else:
exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
return exclude_eids
def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
reverse_etype_map = {
g.to_canonical_etype(k): g.to_canonical_etype(v)
for k, v in reverse_etype_map.items()
}
exclude_eids.update(
{reverse_etype_map[k]: v for k, v in exclude_eids.items()}
)
return exclude_eids
def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""
if exclude_mode is None:
return None
elif exclude_mode == "self":
return eids
elif exclude_mode == "reverse_id":
return _find_exclude_eids_with_reverse_id(
g, eids, kwargs["reverse_eid_map"]
)
elif exclude_mode == "reverse_types":
return _find_exclude_eids_with_reverse_types(
g, eids, kwargs["reverse_etype_map"]
)
else:
raise ValueError("unsupported mode {}".format(exclude_mode))
class Collator(ABC):
@@ -100,6 +109,7 @@ class Collator(ABC):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
@abstractproperty
def dataset(self):
"""Returns the dataset object of the collator."""
@@ -122,6 +132,7 @@ class Collator(ABC):
"""
raise NotImplementedError
class NodeCollator(Collator):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling.
@@ -155,14 +166,16 @@ class NodeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, g, nids, graph_sampler):
self.g = g
if not isinstance(nids, Mapping):
assert (
len(g.ntypes) == 1
), "nids should be a dict of node type and ids for graph with multiple node types"
self.graph_sampler = graph_sampler
self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
self._dataset = utils.maybe_flatten_dict(self.nids)
@property
@@ -197,12 +210,15 @@ class NodeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g, items, "items")
input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(
self.g, items
)
return input_nodes, output_nodes, blocks
class EdgeCollator(Collator):
"""DGL collator to combine edges and their computation dependencies within a minibatch for
training edge classification, edge regression, or link prediction on a single graph
@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(
self,
g,
eids,
graph_sampler,
g_sampling=None,
exclude=None,
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
):
self.g = g
if not isinstance(eids, Mapping):
assert (
len(g.etypes) == 1
), "eids should be a dict of etype and ids for graph with multiple etypes"
self.graph_sampler = graph_sampler
# One may wish to iterate over the edges in one graph while performing sampling in
@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
self.reverse_etypes = reverse_etypes
self.negative_sampler = negative_sampler
self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
self._dataset = utils.maybe_flatten_dict(self.eids)
@property
@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items)
seed_nodes = pair_graph.ndata[NID]
@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes,
)
input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids
)
return input_nodes, pair_graph, blocks
@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
if isinstance(items[0], tuple):
# returns a list of pairs: group them by node types into a dict
items = utils.group_as_dict(items)
items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")
pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
induced_edges = pair_graph.edata[EID]
neg_srcdst = self.negative_sampler(self.g, items)
if not isinstance(neg_srcdst, Mapping):
assert len(self.g.etypes) == 1, (
"graph has multiple or no edge types; "
"please return a dict in negative sampler."
)
neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}
# Get dtype from a tuple of tensors
dtype = F.dtype(list(neg_srcdst.values())[0][0])
ctx = F.context(pair_graph)
neg_edges = {
etype: neg_srcdst.get(
etype,
(
F.copy_to(F.tensor([], dtype), ctx),
F.copy_to(F.tensor([], dtype), ctx),
),
)
for etype in self.g.canonical_etypes
}
neg_pair_graph = heterograph(
neg_edges,
{ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes},
)
pair_graph, neg_pair_graph = transforms.compact_graphs(
[pair_graph, neg_pair_graph]
)
pair_graph.edata[EID] = induced_edges
seed_nodes = pair_graph.ndata[NID]
@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
self.exclude,
items,
reverse_eid_map=self.reverse_eids,
reverse_etype_map=self.reverse_etypes,
)
input_nodes, _, blocks = self.graph_sampler.sample_blocks(
self.g_sampling, seed_nodes, exclude_eids=exclude_eids
)
return input_nodes, pair_graph, neg_pair_graph, blocks
@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
def _remove_kwargs_dist(kwargs):
if "num_workers" in kwargs:
del kwargs["num_workers"]
if "pin_memory" in kwargs:
del kwargs["pin_memory"]
print("Distributed DataLoaders do not support pin_memory.")
return kwargs
class DistNodeDataLoader(DistDataLoader):
"""Sampled graph data loader over nodes for distributed graph storage.
@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
--------
dgl.dataloading.DataLoader
"""
def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
dataloader_kwargs[k] = v
if device is None:
# for the distributed case default to the CPU
device = "cpu"
assert (
device == "cpu"
), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs)
super().__init__(
self.collator.dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs
)
self.device = device
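# A hedged usage sketch for the distributed node loader; assumes an initialized
# DistGraph g and a tensor of training node IDs train_nid:
#
# >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
# >>> dataloader = dgl.dataloading.DistNodeDataLoader(
# ...     g, train_nid, sampler, batch_size=1024, shuffle=True)
# >>> for input_nodes, seeds, blocks in dataloader:
# ...     pass  # train on the sampled blocks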
class DistEdgeDataLoader(DistDataLoader):
"""Sampled graph data loader over edges for distributed graph storage.
@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
--------
dgl.dataloading.DataLoader
"""
def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
if device is None:
# for the distributed case default to the CPU
device = "cpu"
assert (
device == "cpu"
), "Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
_remove_kwargs_dist(dataloader_kwargs)
super().__init__(
self.collator.dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs
)
self.device = device
@@ -17,11 +17,11 @@
#
"""Data loading components for labor sampling"""
from .. import backend as F
from ..base import EID, NID
from ..random import choice
from ..transforms import to_block
from .base import BlockSampler
class LaborSampler(BlockSampler):
@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
)
block.edata[EID] = eid
if len(g.canonical_etypes) > 1:
for etype, importance in zip(g.canonical_etypes, importances):
if importance.shape[0] == block.num_edges(etype):
block.edata["edge_weights"][etype] = importance
elif importances[0].shape[0] == block.num_edges():
"""Data loading components for neighbor sampling"""
from ..base import EID, NID
from ..transforms import to_block
from .base import BlockSampler
class NeighborSampler(BlockSampler):
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(
self,
fanouts,
edge_dir="in",
prob=None,
mask=None,
replace=False,
prefetch_node_feats=None,
prefetch_labels=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__(
prefetch_node_feats=prefetch_node_feats,
prefetch_labels=prefetch_labels,
prefetch_edge_feats=prefetch_edge_feats,
output_device=output_device,
)
self.fanouts = fanouts
self.edge_dir = edge_dir
if mask is not None and prob is not None:
raise ValueError(
"Mask and probability arguments are mutually exclusive. "
"Consider multiplying the probability with the mask "
"to achieve the same goal."
)
self.prob = prob or mask
self.replace = replace
@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
blocks = []
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
seed_nodes,
fanout,
edge_dir=self.edge_dir,
prob=self.prob,
replace=self.replace,
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
......@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
return seed_nodes, output_nodes, blocks
MultiLayerNeighborSampler = NeighborSampler
class MultiLayerFullNeighborSampler(NeighborSampler):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, num_layers, **kwargs):
super().__init__([-1] * num_layers, **kwargs)
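# Usage sketch: a two-layer full-neighbor sampler is just NeighborSampler([-1, -1]),
# as the constructor above shows; g and train_nids are assumed to exist.
#
# >>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
# >>> dataloader = dgl.dataloading.DataLoader(
# ...     g, train_nids, sampler, batch_size=1024, shuffle=True)
# >>> for input_nodes, output_nodes, blocks in dataloader:
# ...     pass  # blocks[0] is the outermost MFG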
"""ShaDow-GNN subgraph samplers."""
from .. import transforms
from ..base import NID
from ..sampling.utils import EidExcluder
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features
class ShaDowKHopSampler(Sampler):
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
def __init__(
self,
fanouts,
replace=False,
prob=None,
prefetch_node_feats=None,
prefetch_edge_feats=None,
output_device=None,
):
super().__init__()
self.fanouts = fanouts
self.replace = replace
@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
self.prefetch_edge_feats = prefetch_edge_feats
self.output_device = output_device
def sample(
self, g, seed_nodes, exclude_eids=None
): # pylint: disable=arguments-differ
"""Sampling function.
Parameters
@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
output_nodes = seed_nodes
for fanout in reversed(self.fanouts):
frontier = g.sample_neighbors(
seed_nodes,
fanout,
output_device=self.output_device,
replace=self.replace,
prob=self.prob,
exclude_edges=exclude_eids,
)
block = transforms.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[NID]
subg = g.subgraph(
seed_nodes, relabel_nodes=True, output_device=self.output_device
)
if exclude_eids is not None:
subg = EidExcluder(exclude_eids)(subg)
@@ -7,8 +7,7 @@ from itertools import product
from .base import BuiltinFunction, TargetCode
__all__ = ["copy_u", "copy_e",
"BinaryMessageFunction", "CopyMessageFunction"]
__all__ = ["copy_u", "copy_e", "BinaryMessageFunction", "CopyMessageFunction"]
class MessageFunction(BuiltinFunction):
@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
--------
u_mul_e
"""
def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
self.binary_op = binary_op
self.lhs = lhs
@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
--------
copy_u
"""
def __init__(self, target, in_field, out_field):
self.target = target
self.in_field = in_field
@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
--------
>>> import dgl
>>> message_func = dgl.function.{}('h', 'h', 'm')
""".format(binary_op,
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
name)
""".format(
binary_op,
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
TargetCode.CODE2STR[_TARGET_MAP[lhs]],
TargetCode.CODE2STR[_TARGET_MAP[rhs]],
name,
)
def func(lhs_field, rhs_field, out):
return BinaryMessageFunction(
binary_op,
_TARGET_MAP[lhs],
_TARGET_MAP[rhs],
lhs_field,
rhs_field,
out,
)
func.__name__ = name
func.__doc__ = docstring
return func
@@ -177,4 +186,5 @@ def _register_builtin_message_func():
setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__)
_register_builtin_message_func()
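# The generator above stamps out builtins such as u_mul_e and u_add_v at import
# time. A short sketch of how a generated message function is used (a graph g
# with node feature 'h' and edge feature 'w' is assumed):
#
# >>> import dgl.function as fn
# >>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))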
"""Module for various graph generator functions."""
from . import backend as F, convert, random
__all__ = ["rand_graph", "rand_bipartite"]
"""Python interfaces to DGL farthest point sampler."""
import numpy as np
from .. import backend as F, ndarray as nd
from .._ffi.base import DGLError
from .._ffi.function import _init_api
@@ -5,11 +5,10 @@ import networkx as nx
import numpy as np
import scipy
from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import dgl_warning, DGLError
class BoolFlag(object):
"""Classes for heterogeneous graphs."""
# pylint: disable= too-many-lines
import copy
import itertools
import numbers
from collections import defaultdict
from collections.abc import Iterable, Mapping
from contextlib import contextmanager
import networkx as nx
import numpy as np
from . import backend as F, core, graph_index, heterograph_index, utils
from ._ffi.function import _init_api
from .base import (
ALL,
dgl_warning,
DGLError,
EID,
ETYPE,
is_all,
NID,
NTYPE,
SLICE_FULL,
)
from .frame import Frame
from .ops import segment
from .view import (
HeteroEdgeDataView,
HeteroEdgeView,
HeteroNodeDataView,
HeteroNodeView,
)
__all__ = ["DGLGraph", "combine_names"]
class DGLGraph(object):
"""Class for storing graph structure and node/edge feature data.
@@ -35,16 +50,19 @@ class DGLGraph(object):
Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its
usage.
"""
is_block = False
# pylint: disable=unused-argument, dangerous-default-value
def __init__(
self,
gidx=[],
ntypes=["_N"],
etypes=["_E"],
node_frames=None,
edge_frames=None,
**deprecate_kwargs
):
"""Internal constructor for creating a DGLGraph.
Parameters
@@ -67,21 +85,42 @@ class DGLGraph(object):
of edge type i. (default: None)
"""
if isinstance(gidx, DGLGraph):
raise DGLError(
"The input is already a DGLGraph. No need to create it again."
)
if not isinstance(gidx, heterograph_index.HeteroGraphIndex):
dgl_warning(
"Recommend creating graphs by `dgl.graph(data)`"
" instead of `dgl.DGLGraph(data)`."
)
(sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(
gidx
)
if sparse_fmt == "coo":
gidx = heterograph_index.create_unitgraph_from_coo(
1,
num_src,
num_dst,
arrays[0],
arrays[1],
["coo", "csr", "csc"],
)
else:
gidx = heterograph_index.create_unitgraph_from_csr(
1,
num_src,
num_dst,
arrays[0],
arrays[1],
arrays[2],
["coo", "csr", "csc"],
sparse_fmt == "csc",
)
if len(deprecate_kwargs) != 0:
dgl_warning(
"Keyword arguments {} are deprecated in v0.5, and can be safely"
" removed in all cases.".format(list(deprecate_kwargs.keys()))
)
self._init(gidx, ntypes, etypes, node_frames, edge_frames)
def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
@@ -94,39 +133,51 @@ class DGLGraph(object):
# Handle node types
if isinstance(ntypes, tuple):
if len(ntypes) != 2:
errmsg = "Invalid input. Expect a pair (srctypes, dsttypes) but got {}".format(
ntypes
)
raise TypeError(errmsg)
if not self._graph.is_metagraph_unibipartite():
raise ValueError(
"Invalid input. The metagraph must be a uni-directional"
" bipartite graph."
)
self._ntypes = ntypes[0] + ntypes[1]
self._srctypes_invmap = {t: i for i, t in enumerate(ntypes[0])}
self._dsttypes_invmap = {
t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1])
}
self._is_unibipartite = True
if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
self._canonical_etypes = [
(ntypes[0][0], etypes[0], ntypes[1][0])
]
else:
self._ntypes = ntypes
if len(ntypes) == 1:
src_dst_map = None
else:
src_dst_map = find_src_dst_ntypes(
self._ntypes, self._graph.metagraph
)
self._is_unibipartite = src_dst_map is not None
if self._is_unibipartite:
self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
else:
self._srctypes_invmap = {
t: i for i, t in enumerate(self._ntypes)
}
self._dsttypes_invmap = self._srctypes_invmap
# Handle edge types
self._etypes = etypes
if self._canonical_etypes is None:
if len(etypes) == 1 and len(ntypes) == 1:
self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])]
else:
self._canonical_etypes = make_canonical_etypes(
self._etypes, self._ntypes, self._graph.metagraph
)
# An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate
@@ -137,21 +188,29 @@ class DGLGraph(object):
self._etype2canonical[ety] = tuple()
else:
self._etype2canonical[ety] = self._canonical_etypes[i]
self._etypes_invmap = {
t: i for i, t in enumerate(self._canonical_etypes)
}
# node and edge frame
if node_frames is None:
node_frames = [None] * len(self._ntypes)
node_frames = [
Frame(num_rows=self._graph.number_of_nodes(i))
if frame is None
else frame
for i, frame in enumerate(node_frames)
]
self._node_frames = node_frames
if edge_frames is None:
edge_frames = [None] * len(self._etypes)
edge_frames = [
Frame(num_rows=self._graph.number_of_edges(i))
if frame is None
else frame
for i, frame in enumerate(edge_frames)
]
self._edge_frames = edge_frames
def __setstate__(self, state):
@@ -162,40 +221,60 @@ class DGLGraph(object):
self.__dict__.update(state)
elif isinstance(state, tuple) and len(state) == 5:
# DGL == 0.4.3
dgl_warning("The object is pickled with DGL == 0.4.3. "
"Some of the original attributes are ignored.")
dgl_warning(
"The object is pickled with DGL == 0.4.3. "
"Some of the original attributes are ignored."
)
self._init(*state)
elif isinstance(state, dict):
# DGL <= 0.4.2
dgl_warning("The object is pickled with DGL <= 0.4.2. "
"Some of the original attributes are ignored.")
self._init(state['_graph'], state['_ntypes'], state['_etypes'], state['_node_frames'],
state['_edge_frames'])
dgl_warning(
"The object is pickled with DGL <= 0.4.2. "
"Some of the original attributes are ignored."
)
self._init(
state["_graph"],
state["_ntypes"],
state["_etypes"],
state["_node_frames"],
state["_edge_frames"],
)
else:
raise IOError("Unrecognized pickle format.")
def __repr__(self):
if len(self.ntypes) == 1 and len(self.etypes) == 1:
ret = (
"Graph(num_nodes={node}, num_edges={edge},\n"
" ndata_schemes={ndata}\n"
" edata_schemes={edata})"
)
return ret.format(
node=self.number_of_nodes(),
edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()),
)
else:
ret = (
"Graph(num_nodes={node},\n"
" num_edges={edge},\n"
" metagraph={meta})"
)
nnode_dict = {
self.ntypes[i]: self._graph.number_of_nodes(i)
for i in range(len(self.ntypes))
}
nedge_dict = {
self.canonical_etypes[i]: self._graph.number_of_edges(i)
for i in range(len(self.etypes))
}
meta = str(self.metagraph().edges(keys=True))
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
def __copy__(self):
"""Shallow copy implementation."""
# TODO(minjie): too many states in python; should clean up and lower to C
cls = type(self)
obj = cls.__new__(cls)
obj.__dict__.update(self.__dict__)
@@ -298,14 +377,16 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support add_nodes
if ntype is None:
if self._graph.number_of_ntypes() != 1:
raise DGLError(
"Node type name must be specified if there are more than one "
"node types."
)
# nothing happens
if num == 0:
return
assert num > 0, "Number of new nodes must be a positive number."
ntid = self.get_ntype_id(ntype)
# update graph idx
metagraph = self._graph.metagraph
@@ -319,23 +400,32 @@ class DGLGraph(object):
relation_graphs = []
for c_etype in self.canonical_etypes:
# src or dst == ntype, update the relation graph
if (
self.get_ntype_id(c_etype[0]) == ntid
or self.get_ntype_id(c_etype[2]) == ntid
):
u, v = self.edges(form="uv", order="eid", etype=c_etype)
hgidx = heterograph_index.create_unitgraph_from_coo(
1 if c_etype[0] == c_etype[2] else 2,
self.number_of_nodes(c_etype[0])
+ (num if self.get_ntype_id(c_etype[0]) == ntid else 0),
self.number_of_nodes(c_etype[2])
+ (num if self.get_ntype_id(c_etype[2]) == ntid else 0),
u,
v,
["coo", "csr", "csc"],
)
relation_graphs.append(hgidx)
else:
# do nothing
relation_graphs.append(
self._graph.get_relation_graph(self.get_etype_id(c_etype))
)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64"))
metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
self._graph = hgidx
# update data frames
@@ -452,26 +542,33 @@ class DGLGraph(object):
remove_edges
"""
# TODO(xiangsx): block do not support add_edges
u = utils.prepare_tensor(self, u, "u")
v = utils.prepare_tensor(self, v, "v")
if etype is None:
if self._graph.number_of_etypes() != 1:
raise DGLError(
"Edge type name must be specified if there are more than one "
"edge types."
)
# nothing changed
if len(u) == 0 or len(v) == 0:
return
assert len(u) == len(v) or len(u) == 1 or len(v) == 1, (
"The number of source nodes and the number of destination nodes should be same, "
"or either the number of source nodes or the number of destination nodes is 1."
)
if len(u) == 1 and len(v) > 1:
u = F.full_1d(
len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u)
)
if len(v) == 1 and len(u) > 1:
v = F.full_1d(
len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v)
)
u_type, e_type, v_type = self.to_canonical_etype(etype)
# if end nodes of adding edges does not exists
@@ -501,22 +598,28 @@ class DGLGraph(object):
for c_etype in self.canonical_etypes:
# the target edge type
if c_etype == (u_type, e_type, v_type):
old_u, old_v = self.edges(form="uv", order="eid", etype=c_etype)
hgidx = heterograph_index.create_unitgraph_from_coo(
1 if u_type == v_type else 2,
self.number_of_nodes(u_type),
self.number_of_nodes(v_type),
F.cat([old_u, u], dim=0),
F.cat([old_v, v], dim=0),
["coo", "csr", "csc"],
)
relation_graphs.append(hgidx)
else:
# do nothing
# Note: node range change has been handled in add_nodes()
relation_graphs.append(
self._graph.get_relation_graph(self.get_etype_id(c_etype))
)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
self._graph = hgidx
# handle data
@@ -607,15 +710,19 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_edges
if etype is None:
if self._graph.number_of_etypes() != 1:
raise DGLError(
"Edge type name must be specified if there are more than one "
"edge types."
)
eids = utils.prepare_tensor(self, eids, "u")
if len(eids) == 0:
# no edge to delete
return
assert self.number_of_edges(etype) > F.as_scalar(
F.max(eids, dim=0)
), "The input eid {} is out of the range [0:{})".format(
F.as_scalar(F.max(eids, dim=0)), self.number_of_edges(etype)
)
# edge_subgraph
edges = {}
@@ -623,25 +730,36 @@ class DGLGraph(object):
for c_etype in self.canonical_etypes:
# the target edge type
if c_etype == (u_type, e_type, v_type):
origin_eids = self.edges(form="eid", order="eid", etype=c_etype)
edges[c_etype] = utils.compensate(eids, origin_eids)
else:
edges[c_etype] = self.edges(
form="eid", order="eid", etype=c_etype
)
# If the graph is batched, update batch_num_edges
batched = self._batch_num_edges is not None
if batched:
c_etype = (u_type, e_type, v_type)
one_hot_removed_edges = F.zeros(
(self.num_edges(c_etype),), F.float32, self.device
)
one_hot_removed_edges = F.scatter_row(
one_hot_removed_edges,
eids,
F.full_1d(len(eids), 1.0, F.float32, self.device),
)
c_etype_batch_num_edges = self._batch_num_edges[c_etype]
batch_num_removed_edges = segment.segment_reduce(
c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum"
)
self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(
batch_num_removed_edges, F.int64
)
sub_g = self.edge_subgraph(
edges, relabel_nodes=False, store_ids=store_ids
)
self._graph = sub_g._graph
self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames
@@ -733,16 +851,20 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_nodes
if ntype is None:
if self._graph.number_of_ntypes() != 1:
raise DGLError(
"Node type name must be specified if there are more than one "
"node types."
)
nids = utils.prepare_tensor(self, nids, "u")
if len(nids) == 0:
# no node to delete
return
assert self.number_of_nodes(ntype) > F.as_scalar(
F.max(nids, dim=0)
), "The input nids {} is out of the range [0:{})".format(
F.as_scalar(F.max(nids, dim=0)), self.number_of_nodes(ntype)
)
ntid = self.get_ntype_id(ntype)
nodes = {}
@@ -757,18 +879,28 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_nodes
batched = self._batch_num_nodes is not None
if batched:
one_hot_removed_nodes = F.zeros(
(self.num_nodes(target_ntype),), F.float32, self.device
)
one_hot_removed_nodes = F.scatter_row(
one_hot_removed_nodes,
nids,
F.full_1d(len(nids), 1.0, F.float32, self.device),
)
c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype]
batch_num_removed_nodes = segment.segment_reduce(
c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer="sum"
)
self._batch_num_nodes[
target_ntype
] = c_ntype_batch_num_nodes - F.astype(
batch_num_removed_nodes, F.int64
)
# Record old num_edges to check later whether some edges were removed
old_num_edges = {
c_etype: self._graph.number_of_edges(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes
}
# node_subgraph
# If batch_num_edges is to be updated, record the original edge IDs
@@ -780,22 +912,36 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_edges
if batched:
canonical_etypes = [
c_etype
for c_etype in self.canonical_etypes
if self._graph.number_of_edges(self.get_etype_id(c_etype))
!= old_num_edges[c_etype]
]
for c_etype in canonical_etypes:
if self._graph.number_of_edges(self.get_etype_id(c_etype)) == 0:
self._batch_num_edges[c_etype] = F.zeros(
(self.batch_size,), F.int64, self.device
)
continue
one_hot_left_edges = F.zeros(
(old_num_edges[c_etype],), F.float32, self.device
)
eids = self.edges[c_etype].data[EID]
one_hot_left_edges = F.scatter_row(
one_hot_left_edges,
eids,
F.full_1d(len(eids), 1.0, F.float32, self.device),
)
batch_num_left_edges = segment.segment_reduce(
self._batch_num_edges[c_etype],
one_hot_left_edges,
reducer="sum",
)
self._batch_num_edges[c_etype] = F.astype(
batch_num_left_edges, F.int64
)
if batched and not store_ids:
for c_ntype in self.ntypes:
@@ -810,7 +956,6 @@ class DGLGraph(object):
self._batch_num_nodes = None
self._batch_num_edges = None
#################################################################
# Metagraph query
#################################################################
@@ -1080,7 +1225,9 @@ class DGLGraph(object):
nx_graph = self._graph.metagraph.to_networkx()
nx_metagraph = nx.MultiDiGraph()
for u_v in nx_graph.edges:
srctype, etype, dsttype = self.canonical_etypes[
nx_graph.edges[u_v]["id"]
]
nx_metagraph.add_edge(srctype, dsttype, etype)
return nx_metagraph
@@ -1133,8 +1280,10 @@ class DGLGraph(object):
"""
if etype is None:
if len(self.etypes) != 1:
raise DGLError(
"Edge type name must be specified if there are more than one "
"edge types."
)
etype = self.etypes[0]
if isinstance(etype, tuple):
return etype
@@ -1143,8 +1292,10 @@ class DGLGraph(object):
if ret is None:
raise DGLError('Edge type "{}" does not exist.'.format(etype))
if len(ret) == 0:
raise DGLError('Edge type "%s" is ambiguous. Please use canonical edge type '
'in the form of (srctype, etype, dsttype)' % etype)
raise DGLError(
'Edge type "%s" is ambiguous. Please use canonical edge type '
"in the form of (srctype, etype, dsttype)" % etype
)
return ret
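# Example (assuming a graph whose only 'follows' relation connects 'user' nodes):
#
# >>> g.to_canonical_etype('follows')
# ('user', 'follows', 'user')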
def get_ntype_id(self, ntype):
@@ -1164,19 +1315,23 @@ class DGLGraph(object):
"""
if self.is_unibipartite and ntype is not None:
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
if ntype.startswith("SRC/"):
return self.get_ntype_id_from_src(ntype[4:])
elif ntype.startswith("DST/"):
return self.get_ntype_id_from_dst(ntype[4:])
# If there is no prefix, fallback to normal lookup.
# Lookup both SRC and DST
if ntype is None:
if self.is_unibipartite or len(self._srctypes_invmap) != 1:
raise DGLError(
"Node type name must be specified if there are more than one "
"node types."
)
return 0
ntid = self._srctypes_invmap.get(
ntype, self._dsttypes_invmap.get(ntype, None)
)
if ntid is None:
raise DGLError('Node type "{}" does not exist.'.format(ntype))
return ntid
@@ -1198,8 +1353,10 @@ class DGLGraph(object):
"""
if ntype is None:
if len(self._srctypes_invmap) != 1:
raise DGLError(
"SRC node type name must be specified if there are more than one "
"SRC node types."
)
return next(iter(self._srctypes_invmap.values()))
ntid = self._srctypes_invmap.get(ntype, None)
if ntid is None:
@@ -1223,8 +1380,10 @@ class DGLGraph(object):
"""
if ntype is None:
if len(self._dsttypes_invmap) != 1:
raise DGLError(
"DST node type name must be specified if there are more than one "
"DST node types."
)
return next(iter(self._dsttypes_invmap.values()))
ntid = self._dsttypes_invmap.get(ntype, None)
if ntid is None:
@@ -1248,8 +1407,10 @@ class DGLGraph(object):
"""
if etype is None:
if self._graph.number_of_etypes() != 1:
raise DGLError(
"Edge type name must be specified if there are more than one "
"edge types."
)
return 0
etid = self._etypes_invmap.get(self.to_canonical_etype(etype), None)
if etid is None:
@@ -1346,17 +1507,23 @@ class DGLGraph(object):
tensor([2, 1])
"""
if ntype is not None and ntype not in self.ntypes:
raise DGLError(
"Expect ntype in {}, got {}".format(self.ntypes, ntype)
)
if self._batch_num_nodes is None:
self._batch_num_nodes = {}
for ty in self.ntypes:
bnn = F.copy_to(
F.tensor([self.number_of_nodes(ty)], F.int64), self.device
)
self._batch_num_nodes[ty] = bnn
if ntype is None:
if len(self.ntypes) != 1:
raise DGLError(
"Node type name must be specified if there are more than one "
"node types."
)
ntype = self.ntypes[0]
return self._batch_num_nodes[ntype]
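# Sketch: for a graph produced by dgl.batch, batch_num_nodes() reports per-sample
# node counts; g1 with 2 nodes and g2 with 3 nodes are assumed below.
#
# >>> bg = dgl.batch([g1, g2])
# >>> bg.batch_num_nodes()
# tensor([2, 3])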
@@ -1440,8 +1607,10 @@ class DGLGraph(object):
"""
if not isinstance(val, Mapping):
if len(self.ntypes) != 1:
raise DGLError(
"Must provide a dictionary when there are multiple node types."
)
val = {self.ntypes[0]: val}
self._batch_num_nodes = val
def batch_num_edges(self, etype=None):
@@ -1494,12 +1663,16 @@ class DGLGraph(object):
if self._batch_num_edges is None:
self._batch_num_edges = {}
for ty in self.canonical_etypes:
bne = F.copy_to(
F.tensor([self.number_of_edges(ty)], F.int64), self.device
)
self._batch_num_edges[ty] = bne
if etype is None:
if len(self.etypes) != 1:
raise DGLError(
"Edge type name must be specified if there are more than one "
"edge types."
)
etype = self.canonical_etypes[0]
else:
etype = self.to_canonical_etype(etype)
@@ -1585,8 +1758,10 @@ class DGLGraph(object):
"""
if not isinstance(val, Mapping):
if len(self.etypes) != 1:
raise DGLError(
"Must provide a dictionary when there are multiple edge types."
)
val = {self.canonical_etypes[0]: val}
self._batch_num_edges = val
#################################################################
@@ -2130,10 +2305,14 @@ class DGLGraph(object):
def _find_etypes(self, key):
etypes = [
i
for i, (srctype, etype, dsttype) in enumerate(
self._canonical_etypes
)
if (key[0] == SLICE_FULL or key[0] == srctype)
and (key[1] == SLICE_FULL or key[1] == etype)
and (key[2] == SLICE_FULL or key[2] == dsttype)
]
return etypes
def __getitem__(self, key):
@@ -2215,9 +2394,11 @@ class DGLGraph(object):
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
"""
err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
"to get view of one relation type. Use : to slice multiple types (e.g. " +\
"G['srctype', :, 'dsttype'])."
err_msg = (
"Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
+ "to get view of one relation type. Use : to slice multiple types (e.g. "
+ "G['srctype', :, 'dsttype'])."
)
orig_key = key
if not isinstance(key, tuple):
......@@ -2229,7 +2410,11 @@ class DGLGraph(object):
etypes = self._find_etypes(key)
if len(etypes) == 0:
raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
raise DGLError(
'Invalid key "{}". Must be one of the edge types.'.format(
orig_key
)
)
if len(etypes) == 1:
# no ambiguity: return the unitgraph itself
......@@ -2248,7 +2433,9 @@ class DGLGraph(object):
new_etypes = [etype]
new_eframes = [self._edge_frames[etid]]
return self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
return self.__class__(
new_g, new_ntypes, new_etypes, new_nframes, new_eframes
)
else:
flat = self._graph.flatten_relations(etypes)
new_g = flat.graph
......@@ -2262,7 +2449,8 @@ class DGLGraph(object):
new_ntypes.append(combine_names(self.ntypes, dtids))
new_nframes = [
combine_frames(self._node_frames, stids),
combine_frames(self._node_frames, dtids)]
combine_frames(self._node_frames, dtids),
]
else:
assert np.array_equal(stids, dtids)
new_nframes = [combine_frames(self._node_frames, stids)]
......@@ -2270,16 +2458,28 @@ class DGLGraph(object):
new_eframes = [combine_frames(self._edge_frames, etids)]
# create new heterograph
new_hg = self.__class__(new_g, new_ntypes, new_etypes, new_nframes, new_eframes)
new_hg = self.__class__(
new_g, new_ntypes, new_etypes, new_nframes, new_eframes
)
src = new_ntypes[0]
dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
# put the parent node/edge type and IDs
new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_srctype)
new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_srcid)
new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_dsttype)
new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(flat.induced_dstid)
new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
flat.induced_srctype
)
new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(
flat.induced_srcid
)
new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
flat.induced_dsttype
)
new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(
flat.induced_dstid
)
new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(
flat.induced_etype
)
new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)
return new_hg
......@@ -2331,7 +2531,12 @@ class DGLGraph(object):
12
"""
if ntype is None:
return sum([self._graph.number_of_nodes(ntid) for ntid in range(len(self.ntypes))])
return sum(
[
self._graph.number_of_nodes(ntid)
for ntid in range(len(self.ntypes))
]
)
else:
return self._graph.number_of_nodes(self.get_ntype_id(ntype))
......@@ -2396,10 +2601,16 @@ class DGLGraph(object):
7
"""
if ntype is None:
return sum([self._graph.number_of_nodes(self.get_ntype_id_from_src(nty))
for nty in self.srctypes])
return sum(
[
self._graph.number_of_nodes(self.get_ntype_id_from_src(nty))
for nty in self.srctypes
]
)
else:
return self._graph.number_of_nodes(self.get_ntype_id_from_src(ntype))
return self._graph.number_of_nodes(
self.get_ntype_id_from_src(ntype)
)
def number_of_dst_nodes(self, ntype=None):
"""Alias of :func:`num_dst_nodes`"""
......@@ -2462,10 +2673,16 @@ class DGLGraph(object):
12
"""
if ntype is None:
return sum([self._graph.number_of_nodes(self.get_ntype_id_from_dst(nty))
for nty in self.dsttypes])
return sum(
[
self._graph.number_of_nodes(self.get_ntype_id_from_dst(nty))
for nty in self.dsttypes
]
)
else:
return self._graph.number_of_nodes(self.get_ntype_id_from_dst(ntype))
return self._graph.number_of_nodes(
self.get_ntype_id_from_dst(ntype)
)
def number_of_edges(self, etype=None):
"""Alias of :func:`num_edges`"""
......@@ -2522,8 +2739,12 @@ class DGLGraph(object):
3
"""
if etype is None:
return sum([self._graph.number_of_edges(etid)
for etid in range(len(self.canonical_etypes))])
return sum(
[
self._graph.number_of_edges(etid)
for etid in range(len(self.canonical_etypes))
]
)
else:
return self._graph.number_of_edges(self.get_etype_id(etype))
......@@ -2708,10 +2929,11 @@ class DGLGraph(object):
tensor([False, True, True])
"""
vid_tensor = utils.prepare_tensor(self, vid, "vid")
if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(vid_tensor):
raise DGLError('All IDs must be non-negative integers.')
ret = self._graph.has_nodes(
self.get_ntype_id(ntype), vid_tensor)
if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(
vid_tensor
):
raise DGLError("All IDs must be non-negative integers.")
ret = self._graph.has_nodes(self.get_ntype_id(ntype), vid_tensor)
if isinstance(vid, numbers.Integral):
return bool(F.as_scalar(ret))
else:
......@@ -2793,15 +3015,19 @@ class DGLGraph(object):
tensor([True, True])
"""
srctype, _, dsttype = self.to_canonical_etype(etype)
u_tensor = utils.prepare_tensor(self, u, 'u')
if F.as_scalar(F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)) != len(u_tensor):
raise DGLError('u contains invalid node IDs')
v_tensor = utils.prepare_tensor(self, v, 'v')
if F.as_scalar(F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)) != len(v_tensor):
raise DGLError('v contains invalid node IDs')
u_tensor = utils.prepare_tensor(self, u, "u")
if F.as_scalar(
F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
) != len(u_tensor):
raise DGLError("u contains invalid node IDs")
v_tensor = utils.prepare_tensor(self, v, "v")
if F.as_scalar(
F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)
) != len(v_tensor):
raise DGLError("v contains invalid node IDs")
ret = self._graph.has_edges_between(
self.get_etype_id(etype),
u_tensor, v_tensor)
self.get_etype_id(etype), u_tensor, v_tensor
)
if isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral):
return bool(F.as_scalar(ret))
else:
......@@ -2863,7 +3089,7 @@ class DGLGraph(object):
successors
"""
if not self.has_nodes(v, self.to_canonical_etype(etype)[-1]):
raise DGLError('Non-existing node ID {}'.format(v))
raise DGLError("Non-existing node ID {}".format(v))
return self._graph.predecessors(self.get_etype_id(etype), v)
def successors(self, v, etype=None):
......@@ -2921,7 +3147,7 @@ class DGLGraph(object):
predecessors
"""
if not self.has_nodes(v, self.to_canonical_etype(etype)[0]):
raise DGLError('Non-existing node ID {}'.format(v))
raise DGLError("Non-existing node ID {}".format(v))
return self._graph.successors(self.get_etype_id(etype), v)
def edge_ids(self, u, v, return_uv=False, etype=None):
......@@ -3018,14 +3244,20 @@ class DGLGraph(object):
... etype=('user', 'follows', 'game'))
tensor([1, 2])
"""
is_int = isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral)
is_int = isinstance(u, numbers.Integral) and isinstance(
v, numbers.Integral
)
srctype, _, dsttype = self.to_canonical_etype(etype)
u = utils.prepare_tensor(self, u, 'u')
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u):
raise DGLError('u contains invalid node IDs')
v = utils.prepare_tensor(self, v, 'v')
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v):
raise DGLError('v contains invalid node IDs')
u = utils.prepare_tensor(self, u, "u")
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(
u
):
raise DGLError("u contains invalid node IDs")
v = utils.prepare_tensor(self, v, "v")
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(
v
):
raise DGLError("v contains invalid node IDs")
if return_uv:
return self._graph.edge_ids_all(self.get_etype_id(etype), u, v)
......@@ -3035,9 +3267,13 @@ class DGLGraph(object):
if F.as_scalar(F.sum(is_neg_one, 0)):
# Raise error since some (u, v) pair is not a valid edge.
idx = F.nonzero_1d(is_neg_one)
raise DGLError("Error: (%d, %d) does not form a valid edge." % (
F.as_scalar(F.gather_row(u, idx)),
F.as_scalar(F.gather_row(v, idx))))
raise DGLError(
"Error: (%d, %d) does not form a valid edge."
% (
F.as_scalar(F.gather_row(u, idx)),
F.as_scalar(F.gather_row(v, idx)),
)
)
return F.as_scalar(eid) if is_int else eid
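# Usage sketch: scalar inputs yield a scalar edge ID, while return_uv=True
# yields (u, v, eid) triples that also cover duplicate edges:
#
#     eid = g.edge_ids(0, 1)
#     u, v, eid = g.edge_ids(torch.tensor([0]), torch.tensor([1]), return_uv=True)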
def find_edges(self, eid, etype=None):
......@@ -3096,14 +3332,14 @@ class DGLGraph(object):
>>> hg.find_edges(torch.tensor([1, 0]), 'plays')
(tensor([4, 3]), tensor([6, 5]))
"""
eid = utils.prepare_tensor(self, eid, 'eid')
eid = utils.prepare_tensor(self, eid, "eid")
if len(eid) > 0:
min_eid = F.as_scalar(F.min(eid, 0))
if min_eid < 0:
raise DGLError('Invalid edge ID {:d}'.format(min_eid))
raise DGLError("Invalid edge ID {:d}".format(min_eid))
max_eid = F.as_scalar(F.max(eid, 0))
if max_eid >= self.num_edges(etype):
raise DGLError('Invalid edge ID {:d}'.format(max_eid))
raise DGLError("Invalid edge ID {:d}".format(max_eid))
if len(eid) == 0:
empty = F.copy_to(F.tensor([], self.idtype), self.device)
......@@ -3111,7 +3347,7 @@ class DGLGraph(object):
src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
return src, dst
def in_edges(self, v, form='uv', etype=None):
def in_edges(self, v, form="uv", etype=None):
"""Return the incoming edges of the given nodes.
Parameters
......@@ -3184,18 +3420,20 @@ class DGLGraph(object):
edges
out_edges
"""
v = utils.prepare_tensor(self, v, 'v')
v = utils.prepare_tensor(self, v, "v")
src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
if form == 'all':
if form == "all":
return src, dst, eid
elif form == 'uv':
elif form == "uv":
return src, dst
elif form == 'eid':
elif form == "eid":
return eid
else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form))
raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
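# Usage sketch for the three `form` variants (hypothetical graph `g`):
#
#     src, dst, eid = g.in_edges(torch.tensor([1, 2]), form="all")
#     src, dst = g.in_edges(torch.tensor([1, 2]))           # form="uv" (default)
#     eid = g.in_edges(torch.tensor([1, 2]), form="eid")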
def out_edges(self, u, form='uv', etype=None):
def out_edges(self, u, form="uv", etype=None):
"""Return the outgoing edges of the given nodes.
Parameters
......@@ -3268,21 +3506,25 @@ class DGLGraph(object):
edges
in_edges
"""
u = utils.prepare_tensor(self, u, 'u')
u = utils.prepare_tensor(self, u, "u")
srctype, _, _ = self.to_canonical_etype(etype)
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u):
raise DGLError('u contains invalid node IDs')
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(
u
):
raise DGLError("u contains invalid node IDs")
src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
if form == 'all':
if form == "all":
return src, dst, eid
elif form == 'uv':
elif form == "uv":
return src, dst
elif form == 'eid':
elif form == "eid":
return eid
else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form))
raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
def all_edges(self, form='uv', order='eid', etype=None):
def all_edges(self, form="uv", order="eid", etype=None):
"""Return all edges with the specified edge type.
Parameters
......@@ -3353,14 +3595,16 @@ class DGLGraph(object):
out_edges
"""
src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
if form == 'all':
if form == "all":
return src, dst, eid
elif form == 'uv':
elif form == "uv":
return src, dst
elif form == 'eid':
elif form == "eid":
return eid
else:
raise DGLError('Invalid form: {}. Must be "all", "uv" or "eid".'.format(form))
raise DGLError(
'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
)
def in_degrees(self, v=ALL, etype=None):
"""Return the in-degree(s) of the given nodes.
......@@ -3431,7 +3675,7 @@ class DGLGraph(object):
etid = self.get_etype_id(etype)
if is_all(v):
v = self.dstnodes(dsttype)
v_tensor = utils.prepare_tensor(self, v, 'v')
v_tensor = utils.prepare_tensor(self, v, "v")
deg = self._graph.in_degrees(etid, v_tensor)
if isinstance(v, numbers.Integral):
return F.as_scalar(deg)
......@@ -3507,16 +3751,20 @@ class DGLGraph(object):
etid = self.get_etype_id(etype)
if is_all(u):
u = self.srcnodes(srctype)
u_tensor = utils.prepare_tensor(self, u, 'u')
if F.as_scalar(F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)) != len(u_tensor):
raise DGLError('u contains invalid node IDs')
deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, 'u'))
u_tensor = utils.prepare_tensor(self, u, "u")
if F.as_scalar(
F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
) != len(u_tensor):
raise DGLError("u contains invalid node IDs")
deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, "u"))
if isinstance(u, numbers.Integral):
return F.as_scalar(deg)
else:
return deg
def adjacency_matrix(self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None):
def adjacency_matrix(
self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None
):
"""Alias of :meth:`adj`"""
return self.adj(transpose, ctx, scipy_fmt, etype)
......@@ -3586,7 +3834,9 @@ class DGLGraph(object):
if scipy_fmt is None:
return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
else:
return self._graph.adjacency_matrix_scipy(etid, transpose, scipy_fmt, False)
return self._graph.adjacency_matrix_scipy(
etid, transpose, scipy_fmt, False
)
def adj_sparse(self, fmt, etype=None):
"""Return the adjacency matrix of edges of the given edge type as tensors of
......@@ -3629,9 +3879,9 @@ class DGLGraph(object):
(tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))
"""
etid = self.get_etype_id(etype)
if fmt == 'csc':
if fmt == "csc":
# The first two elements are number of rows and columns
return self._graph.adjacency_matrix_tensors(etid, True, 'csr')[2:]
return self._graph.adjacency_matrix_tensors(etid, True, "csr")[2:]
else:
return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:]
......@@ -4024,26 +4274,36 @@ class DGLGraph(object):
if is_all(u):
num_nodes = self._graph.number_of_nodes(ntid)
else:
u = utils.prepare_tensor(self, u, 'u')
u = utils.prepare_tensor(self, u, "u")
num_nodes = len(u)
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_nodes:
raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
raise DGLError(
"Expect number of features to match number of nodes (len(u))."
" Got %d and %d instead." % (nfeats, num_nodes)
)
if F.context(val) != self.device:
raise DGLError('Cannot assign node feature "{}" on device {} to a graph on'
' device {}. Call DGLGraph.to() to copy the graph to the'
' same device.'.format(key, F.context(val), self.device))
raise DGLError(
'Cannot assign node feature "{}" on device {} to a graph on'
" device {}. Call DGLGraph.to() to copy the graph to the"
" same device.".format(key, F.context(val), self.device)
)
# To prevent users from doing things like:
#
# g.pin_memory_()
# g.ndata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.ndata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if self.is_pinned() and F.context(val) == 'cpu' and not F.is_pinned(val):
raise DGLError('Pinned graph requires the node data to be pinned as well. '
'Please pin the node data before assignment.')
if (
self.is_pinned()
and F.context(val) == "cpu"
and not F.is_pinned(val)
):
raise DGLError(
"Pinned graph requires the node data to be pinned as well. "
"Please pin the node data before assignment."
)
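# A minimal sketch of the workflow this check enforces, assuming the PyTorch
# backend (torch.Tensor.pin_memory returns a page-locked CPU copy):
#
#     g.pin_memory_()
#     feat = torch.randn(g.num_nodes(), 16).pin_memory()
#     g.ndata["x"] = feat   # accepted: the data is pinned like the graph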
if is_all(u):
self._node_frames[ntid].update(data)
......@@ -4070,7 +4330,7 @@ class DGLGraph(object):
if is_all(u):
return self._node_frames[ntid]
else:
u = utils.prepare_tensor(self, u, 'u')
u = utils.prepare_tensor(self, u, "u")
return self._node_frames[ntid].subframe(u)
def _pop_n_repr(self, ntid, key):
......@@ -4116,12 +4376,14 @@ class DGLGraph(object):
"""
# parse argument
if not is_all(edges):
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
# sanity check
if not utils.is_dict_like(data):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(data))
raise DGLError(
"Expect dictionary type for feature data."
' Got "%s" instead.' % type(data)
)
if is_all(edges):
num_edges = self._graph.number_of_edges(etid)
......@@ -4130,21 +4392,31 @@ class DGLGraph(object):
for key, val in data.items():
nfeats = F.shape(val)[0]
if nfeats != num_edges:
raise DGLError('Expect number of features to match number of edges.'
' Got %d and %d instead.' % (nfeats, num_edges))
raise DGLError(
"Expect number of features to match number of edges."
" Got %d and %d instead." % (nfeats, num_edges)
)
if F.context(val) != self.device:
raise DGLError('Cannot assign edge feature "{}" on device {} to a graph on'
' device {}. Call DGLGraph.to() to copy the graph to the'
' same device.'.format(key, F.context(val), self.device))
raise DGLError(
'Cannot assign edge feature "{}" on device {} to a graph on'
" device {}. Call DGLGraph.to() to copy the graph to the"
" same device.".format(key, F.context(val), self.device)
)
# To prevent users from doing things like:
#
# g.pin_memory_()
# g.edata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if self.is_pinned() and F.context(val) == 'cpu' and not F.is_pinned(val):
raise DGLError('Pinned graph requires the edge data to be pinned as well. '
'Please pin the edge data before assignment.')
if (
self.is_pinned()
and F.context(val) == "cpu"
and not F.is_pinned(val)
):
raise DGLError(
"Pinned graph requires the edge data to be pinned as well. "
"Please pin the edge data before assignment."
)
# set
if is_all(edges):
......@@ -4172,7 +4444,7 @@ class DGLGraph(object):
if is_all(edges):
return self._edge_frames[etid]
else:
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
return self._edge_frames[etid].subframe(eid)
def _pop_e_repr(self, etid, key):
......@@ -4256,7 +4528,7 @@ class DGLGraph(object):
if is_all(v):
v_id = self.nodes(ntype)
else:
v_id = utils.prepare_tensor(self, v, 'v')
v_id = utils.prepare_tensor(self, v, "v")
ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
self._set_n_repr(ntid, v, ndata)
......@@ -4346,16 +4618,18 @@ class DGLGraph(object):
etid = self.get_etype_id(etype)
etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype]
else: # heterogeneous graph with number of relation types > 1
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(func):
raise DGLError("User defined functions are not yet "
"supported in apply_edges for heterogeneous graphs. "
"Please use (apply_edges(func), etype = rel) instead.")
raise DGLError(
"User defined functions are not yet "
"supported in apply_edges for heterogeneous graphs. "
"Please use (apply_edges(func), etype = rel) instead."
)
g = self
if is_all(edges):
eid = ALL
else:
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
if core.is_builtin(func):
if not is_all(eid):
g = g.edge_subgraph(eid, relabel_nodes=False)
......@@ -4375,12 +4649,9 @@ class DGLGraph(object):
edata_tensor[key] = out_tensor_tuples[etid]
self._set_e_repr(etid, eid, edata_tensor)
def send_and_recv(self,
edges,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
def send_and_recv(
self, edges, message_func, reduce_func, apply_node_func=None, etype=None
):
"""Send messages along the specified edges and reduce them on
the destination nodes to update their features.
......@@ -4493,7 +4764,7 @@ class DGLGraph(object):
_, dtid = self._graph.metagraph.find_edge(etid)
etype = self.canonical_etypes[etid]
# edge IDs
eid = utils.parse_edges_arg_to_eid(self, edges, etid, 'edges')
eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
if len(eid) == 0:
# no computation
return
......@@ -4502,15 +4773,13 @@ class DGLGraph(object):
g = self if etype is None else self[etype]
compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid)
ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func)
compute_graph, message_func, reduce_func, apply_node_func
)
self._set_n_repr(dtid, dstnodes, ndata)
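# Usage sketch with builtin functions (assumes `import dgl.function as fn`):
#
#     g.send_and_recv((torch.tensor([0, 1]), torch.tensor([1, 2])),
#                     fn.copy_u("h", "m"), fn.sum("m", "h"))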
def pull(self,
v,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
def pull(
self, v, message_func, reduce_func, apply_node_func=None, etype=None
):
"""Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features.
......@@ -4588,7 +4857,7 @@ class DGLGraph(object):
[1.],
[1.]])
"""
v = utils.prepare_tensor(self, v, 'v')
v = utils.prepare_tensor(self, v, "v")
if len(v) == 0:
# no computation
return
......@@ -4597,18 +4866,18 @@ class DGLGraph(object):
etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype]
# call message passing on subgraph
src, dst, eid = g.in_edges(v, form='all')
compute_graph, _, dstnodes, _ = _create_compute_graph(g, src, dst, eid, v)
src, dst, eid = g.in_edges(v, form="all")
compute_graph, _, dstnodes, _ = _create_compute_graph(
g, src, dst, eid, v
)
ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func)
compute_graph, message_func, reduce_func, apply_node_func
)
self._set_n_repr(dtid, dstnodes, ndata)
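# Usage sketch: aggregate incoming messages only into nodes 1 and 2
# (assumes `import dgl.function as fn`):
#
#     g.pull(torch.tensor([1, 2]), fn.copy_u("h", "m"), fn.sum("m", "h"))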
def push(self,
u,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
def push(
self, u, message_func, reduce_func, apply_node_func=None, etype=None
):
"""Send message from the specified node(s) to their successors
along the specified edge type and update their node features.
......@@ -4679,14 +4948,14 @@ class DGLGraph(object):
[0.],
[0.]])
"""
edges = self.out_edges(u, form='eid', etype=etype)
self.send_and_recv(edges, message_func, reduce_func, apply_node_func, etype=etype)
def update_all(self,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
edges = self.out_edges(u, form="eid", etype=etype)
self.send_and_recv(
edges, message_func, reduce_func, apply_node_func, etype=etype
)
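# Usage sketch: push node 0's feature to its successors and sum there
# (assumes `import dgl.function as fn`):
#
#     g.push(0, fn.copy_u("h", "m"), fn.sum("m", "h"))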
def update_all(
self, message_func, reduce_func, apply_node_func=None, etype=None
):
"""Send messages along all the edges of the specified type
and update all the nodes of the corresponding destination type.
......@@ -4778,23 +5047,37 @@ class DGLGraph(object):
etype = self.canonical_etypes[etid]
_, dtid = self._graph.metagraph.find_edge(etid)
g = self if etype is None else self[etype]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
ndata = core.message_passing(
g, message_func, reduce_func, apply_node_func
)
if (
core.is_builtin(reduce_func)
and reduce_func.name in ["min", "max"]
and ndata
):
# Replace infinity with zero for isolated nodes
key = list(ndata.keys())[0]
ndata[key] = F.replace_inf_with_zero(ndata[key])
self._set_n_repr(dtid, ALL, ndata)
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):
raise DGLError("User defined functions are not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.")
if reduce_func.name in ['mean']:
raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
"multi_update_all instead.")
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(message_func) or not core.is_builtin(
reduce_func
):
raise DGLError(
"User defined functions are not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead."
)
if reduce_func.name in ["mean"]:
raise NotImplementedError(
"Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
"multi_update_all instead."
)
g = self
all_out = core.message_passing(g, message_func, reduce_func, apply_node_func)
all_out = core.message_passing(
g, message_func, reduce_func, apply_node_func
)
key = list(all_out.keys())[0]
out_tensor_tuples = all_out[key]
......@@ -4802,7 +5085,10 @@ class DGLGraph(object):
for _, _, dsttype in g.canonical_etypes:
dtid = g.get_ntype_id(dsttype)
dst_tensor[key] = out_tensor_tuples[dtid]
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max']:
if core.is_builtin(reduce_func) and reduce_func.name in [
"min",
"max",
]:
dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key])
self._node_frames[dtid].update(dst_tensor)
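# Usage sketch for the single-relation path (assumes `import dgl.function as fn`):
#
#     g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_new"))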
......@@ -4902,36 +5188,44 @@ class DGLGraph(object):
_, dtid = self._graph.metagraph.find_edge(etid)
args = pad_tuple(args, 3)
if args is None:
raise DGLError('Invalid arguments for edge type "{}". Should be '
'(msg_func, reduce_func, [apply_node_func])'.format(etype))
raise DGLError(
'Invalid arguments for edge type "{}". Should be '
"(msg_func, reduce_func, [apply_node_func])".format(etype)
)
mfunc, rfunc, afunc = args
g = self if etype is None else self[etype]
all_out[dtid].append(core.message_passing(g, mfunc, rfunc, afunc))
merge_order[dtid].append(etid) # use edge type id as merge order hint
merge_order[dtid].append(
etid
) # use edge type id as merge order hint
for dtid, frames in all_out.items():
# merge by cross_reducer
out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])
# Replace infinity with zero for isolated nodes when reducer is min/max
if core.is_builtin(rfunc) and rfunc.name in ['min', 'max']:
if core.is_builtin(rfunc) and rfunc.name in ["min", "max"]:
key = list(out.keys())[0]
out[key] = F.replace_inf_with_zero(out[key]) if out[key] is not None else None
out[key] = (
F.replace_inf_with_zero(out[key])
if out[key] is not None
else None
)
self._node_frames[dtid].update(out)
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])
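# Usage sketch (hypothetical heterograph with 'follows' and 'plays' relations;
# assumes `import dgl.function as fn`):
#
#     g.multi_update_all(
#         {"follows": (fn.copy_u("h", "m"), fn.sum("m", "h")),
#          "plays": (fn.copy_u("h", "m"), fn.sum("m", "h"))},
#         cross_reducer="sum")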
#################################################################
# Message propagation
#################################################################
def prop_nodes(self,
nodes_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
def prop_nodes(
self,
nodes_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None,
):
"""Propagate messages using graph traversal by sequentially triggering
:func:`pull()` on nodes.
......@@ -4987,14 +5281,22 @@ class DGLGraph(object):
prop_edges
"""
for node_frontier in nodes_generator:
self.pull(node_frontier, message_func, reduce_func, apply_node_func, etype=etype)
def prop_edges(self,
edges_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None):
self.pull(
node_frontier,
message_func,
reduce_func,
apply_node_func,
etype=etype,
)
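# Usage sketch: propagate along a topological order of a DAG, where
# dgl.topological_nodes_generator yields one node frontier per step:
#
#     g.prop_nodes(dgl.topological_nodes_generator(g),
#                  fn.copy_u("h", "m"), fn.sum("m", "h"))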
def prop_edges(
self,
edges_generator,
message_func,
reduce_func,
apply_node_func=None,
etype=None,
):
"""Propagate messages using graph traversal by sequentially triggering
:func:`send_and_recv()` on edges.
......@@ -5051,8 +5353,13 @@ class DGLGraph(object):
prop_nodes
"""
for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier, message_func, reduce_func,
apply_node_func, etype=etype)
self.send_and_recv(
edge_frontier,
message_func,
reduce_func,
apply_node_func,
etype=etype,
)
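# Analogous sketch with an edge frontier generator, e.g. a BFS ordering:
#
#     g.prop_edges(dgl.bfs_edges_generator(g, 0),
#                  fn.copy_u("h", "m"), fn.sum("m", "h"))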
#################################################################
# Misc
......@@ -5127,14 +5434,16 @@ class DGLGraph(object):
"""
if is_all(nodes):
nodes = self.nodes(ntype)
v = utils.prepare_tensor(self, nodes, 'nodes')
v = utils.prepare_tensor(self, nodes, "nodes")
if F.as_scalar(F.sum(self.has_nodes(v, ntype=ntype), dim=0)) != len(v):
raise DGLError('v contains invalid node IDs')
raise DGLError("v contains invalid node IDs")
with self.local_scope():
self.apply_nodes(lambda nbatch: {'_mask' : predicate(nbatch)}, nodes, ntype)
self.apply_nodes(
lambda nbatch: {"_mask": predicate(nbatch)}, nodes, ntype
)
ntype = self.ntypes[0] if ntype is None else ntype
mask = self.nodes[ntype].data['_mask']
mask = self.nodes[ntype].data["_mask"]
if is_all(nodes):
return F.nonzero_1d(mask)
else:
......@@ -5221,34 +5530,40 @@ class DGLGraph(object):
elif isinstance(edges, tuple):
u, v = edges
srctype, _, dsttype = self.to_canonical_etype(etype)
u = utils.prepare_tensor(self, u, 'u')
if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u):
raise DGLError('edges[0] contains invalid node IDs')
v = utils.prepare_tensor(self, v, 'v')
if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v):
raise DGLError('edges[1] contains invalid node IDs')
u = utils.prepare_tensor(self, u, "u")
if F.as_scalar(
F.sum(self.has_nodes(u, ntype=srctype), dim=0)
) != len(u):
raise DGLError("edges[0] contains invalid node IDs")
v = utils.prepare_tensor(self, v, "v")
if F.as_scalar(
F.sum(self.has_nodes(v, ntype=dsttype), dim=0)
) != len(v):
raise DGLError("edges[1] contains invalid node IDs")
elif isinstance(edges, Iterable) or F.is_tensor(edges):
edges = utils.prepare_tensor(self, edges, 'edges')
edges = utils.prepare_tensor(self, edges, "edges")
min_eid = F.as_scalar(F.min(edges, 0))
if len(edges) > 0 > min_eid:
raise DGLError('Invalid edge ID {:d}'.format(min_eid))
raise DGLError("Invalid edge ID {:d}".format(min_eid))
max_eid = F.as_scalar(F.max(edges, 0))
if len(edges) > 0 and max_eid >= self.num_edges(etype):
raise DGLError('Invalid edge ID {:d}'.format(max_eid))
raise DGLError("Invalid edge ID {:d}".format(max_eid))
else:
raise ValueError('Unsupported type of edges:', type(edges))
raise ValueError("Unsupported type of edges:", type(edges))
with self.local_scope():
self.apply_edges(lambda ebatch: {'_mask' : predicate(ebatch)}, edges, etype)
self.apply_edges(
lambda ebatch: {"_mask": predicate(ebatch)}, edges, etype
)
etype = self.canonical_etypes[0] if etype is None else etype
mask = self.edges[etype].data['_mask']
mask = self.edges[etype].data["_mask"]
if is_all(edges):
return F.nonzero_1d(mask)
else:
if isinstance(edges, tuple):
e = self.edge_ids(edges[0], edges[1], etype=etype)
else:
e = utils.prepare_tensor(self, edges, 'edges')
e = utils.prepare_tensor(self, edges, "edges")
return F.boolean_mask(e, F.gather_row(mask, e))
@property
......@@ -5347,12 +5662,16 @@ class DGLGraph(object):
# 2. Copy misc info
if self._batch_num_nodes is not None:
new_bnn = {k : F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_nodes.items()}
new_bnn = {
k: F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_nodes.items()
}
ret._batch_num_nodes = new_bnn
if self._batch_num_edges is not None:
new_bne = {k : F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_edges.items()}
new_bne = {
k: F.copy_to(num, device, **kwargs)
for k, num in self._batch_num_edges.items()
}
ret._batch_num_edges = new_bne
return ret
......@@ -5432,8 +5751,10 @@ class DGLGraph(object):
tensor([0, 1, 1])
"""
if not self._graph.is_pinned():
if F.device_type(self.device) != 'cpu':
raise DGLError("The graph structure must be on CPU to be pinned.")
if F.device_type(self.device) != "cpu":
raise DGLError(
"The graph structure must be on CPU to be pinned."
)
self._graph.pin_memory_()
for frame in itertools.chain(self._node_frames, self._edge_frames):
for col in frame._columns.values():
......@@ -5484,9 +5805,9 @@ class DGLGraph(object):
DGLGraph
self.
"""
if F.get_preferred_backend() != 'pytorch':
if F.get_preferred_backend() != "pytorch":
raise DGLError("record_stream only support the PyTorch backend.")
if F.device_type(self.device) != 'cuda':
if F.device_type(self.device) != "cuda":
raise DGLError("The graph must be on GPU to be recorded.")
self._graph.record_stream(stream)
for frame in itertools.chain(self._node_frames, self._edge_frames):
......@@ -5510,15 +5831,24 @@ class DGLGraph(object):
# Clone the graph structure
meta_edges = []
for s_ntype, _, d_ntype in self.canonical_etypes:
meta_edges.append((self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype)))
meta_edges.append(
(self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))
)
metagraph = graph_index.from_edge_list(meta_edges, True)
# rebuild graph idx
num_nodes_per_type = [self.number_of_nodes(c_ntype) for c_ntype in self.ntypes]
relation_graphs = [self._graph.get_relation_graph(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes]
num_nodes_per_type = [
self.number_of_nodes(c_ntype) for c_ntype in self.ntypes
]
relation_graphs = [
self._graph.get_relation_graph(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes
]
ret._graph = heterograph_index.create_heterograph_from_relations(
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64"))
metagraph,
relation_graphs,
utils.toindex(num_nodes_per_type, "int64"),
)
# Clone the frames
ret._node_frames = [fr.clone() for fr in self._node_frames]
......@@ -5819,7 +6149,7 @@ class DGLGraph(object):
return ret
# TODO: Formats should not be specified, just saving all the materialized formats
def shared_memory(self, name, formats=('coo', 'csr', 'csc')):
def shared_memory(self, name, formats=("coo", "csr", "csc")):
"""Return a copy of this graph in shared memory, without node data or edge data.
It moves the graph index to shared memory and returns a DGLGraph object which
......@@ -5843,11 +6173,16 @@ class DGLGraph(object):
if isinstance(formats, str):
formats = [formats]
for fmt in formats:
assert fmt in ("coo", "csr", "csc"), '{} is not coo, csr or csc'.format(fmt)
gidx = self._graph.shared_memory(name, self.ntypes, self.etypes, formats)
assert fmt in (
"coo",
"csr",
"csc",
), "{} is not coo, csr or csc".format(fmt)
gidx = self._graph.shared_memory(
name, self.ntypes, self.etypes, formats
)
return DGLGraph(gidx, self.ntypes, self.etypes)
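# Usage sketch: place the graph structure in shared memory under a name that
# another process can attach to:
#
#     shm_g = g.shared_memory("my_graph", formats="csc")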
def long(self):
"""Cast the graph to one with idtype int64
......@@ -5948,10 +6283,12 @@ class DGLGraph(object):
"""
return self.astype(F.int32)
############################################################
# Internal APIs
############################################################
def make_canonical_etypes(etypes, ntypes, metagraph):
"""Internal function to convert etype name to (srctype, etype, dsttype)
......@@ -5970,19 +6307,29 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
"""
# sanity check
if len(etypes) != metagraph.number_of_edges():
raise DGLError('Length of edge type list must match the number of '
'edges in the metagraph. {} vs {}'.format(
len(etypes), metagraph.number_of_edges()))
raise DGLError(
"Length of edge type list must match the number of "
"edges in the metagraph. {} vs {}".format(
len(etypes), metagraph.number_of_edges()
)
)
if len(ntypes) != metagraph.number_of_nodes():
raise DGLError('Length of nodes type list must match the number of '
'nodes in the metagraph. {} vs {}'.format(
len(ntypes), metagraph.number_of_nodes()))
if (len(etypes) == 1 and len(ntypes) == 1):
raise DGLError(
"Length of node type list must match the number of "
"nodes in the metagraph. {} vs {}".format(
len(ntypes), metagraph.number_of_nodes()
)
)
if len(etypes) == 1 and len(ntypes) == 1:
return [(ntypes[0], etypes[0], ntypes[0])]
src, dst, eid = metagraph.edges(order="eid")
rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)]
rst = [
(ntypes[sid], etypes[eid], ntypes[did])
for sid, did, eid in zip(src, dst, eid)
]
return rst
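# Illustrative example: with ntypes ['user', 'game'], etypes
# ['follows', 'plays'] and metagraph edges user->user and user->game,
# the result is [('user', 'follows', 'user'), ('user', 'plays', 'game')].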
def find_src_dst_ntypes(ntypes, metagraph):
"""Internal function to split ntypes into SRC and DST categories.
......@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph):
return None
else:
src, dst = ret
srctypes = {ntypes[tid] : tid for tid in src}
dsttypes = {ntypes[tid] : tid for tid in dst}
srctypes = {ntypes[tid]: tid for tid in src}
dsttypes = {ntypes[tid]: tid for tid in dst}
return srctypes, dsttypes
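# Returns None when the metagraph admits no SRC/DST split; otherwise maps
# each type name to its type ID, e.g. ({'user': 0}, {'game': 1}).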
def pad_tuple(tup, length, pad_val=None):
"""Pad the given tuple to the given length.
......@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None):
Return None if pad fails.
"""
if not isinstance(tup, tuple):
tup = (tup, )
tup = (tup,)
if len(tup) > length:
return None
elif len(tup) == length:
......@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None):
else:
return tup + (pad_val,) * (length - len(tup))
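# e.g. pad_tuple((mfunc, rfunc), 3) -> (mfunc, rfunc, None)
#      pad_tuple((a, b, c, d), 3)  -> None (longer than the target length)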
def reduce_dict_data(frames, reducer, order=None):
"""Merge tensor dictionaries into one. Resolve conflict fields using reducer.
......@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None):
dict[str, Tensor]
Merged frame
"""
if len(frames) == 1 and reducer != 'stack':
if len(frames) == 1 and reducer != "stack":
# Directly return the only one input. Stack reducer requires
# modifying tensor shape.
return frames[0]
if callable(reducer):
merger = reducer
elif reducer == 'stack':
elif reducer == "stack":
# Stack order does not matter. However, it must be consistent!
if order:
assert len(order) == len(frames)
sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
frames = list(zip(*sorted_with_key))[0]
def merger(flist):
return F.stack(flist, 1)
else:
redfn = getattr(F, reducer, None)
if redfn is None:
raise DGLError('Invalid cross type reducer. Must be one of '
'"sum", "max", "min", "mean" or "stack".')
raise DGLError(
"Invalid cross type reducer. Must be one of "
'"sum", "max", "min", "mean" or "stack".'
)
def merger(flist):
return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
keys = set()
for frm in frames:
keys.update(frm.keys())
......@@ -6087,6 +6442,7 @@ def reduce_dict_data(frames, reducer, order=None):
ret[k] = merger(flist)
return ret
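# Illustrative sketch (PyTorch backend assumed): conflicting keys are stacked
# along a new axis and reduced, so
#
#     frames = [{"h": torch.ones(3, 2)}, {"h": torch.full((3, 2), 2.0)}]
#     reduce_dict_data(frames, "sum")["h"]   # tensor filled with 3.0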
def combine_frames(frames, ids, col_names=None):
"""Merge the frames into one frame, taking the common columns.
......@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None):
for key, scheme in list(schemes.items()):
if key in frame.schemes:
if frame.schemes[key] != scheme:
raise DGLError('Cannot concatenate column %s with shape %s and shape %s' %
(key, frame.schemes[key], scheme))
raise DGLError(
"Cannot concatenate column %s with shape %s and shape %s"
% (key, frame.schemes[key], scheme)
)
else:
del schemes[key]
......@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None):
cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
return Frame(cols)
def combine_names(names, ids=None):
"""Combine the selected names into one new name.
......@@ -6148,40 +6507,59 @@ def combine_names(names, ids=None):
str
"""
if ids is None:
return '+'.join(sorted(names))
return "+".join(sorted(names))
else:
selected = sorted([names[i] for i in ids])
return '+'.join(selected)
return "+".join(selected)
class DGLBlock(DGLGraph):
"""Subclass that signifies the graph is a block created from
:func:`dgl.to_block`.
"""
# (BarclayII) I'm making a subclass because I don't want to make another version of
# serialization that contains the is_block flag.
is_block = True
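# e.g. blocks produced by dgl.to_block(g, dst_nodes) report is_block == True,
# letting downstream code tell MFGs apart from ordinary graphs.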
def __repr__(self):
if len(self.srctypes) == 1 and len(self.dsttypes) == 1 and len(self.etypes) == 1:
ret = 'Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})'
if (
len(self.srctypes) == 1
and len(self.dsttypes) == 1
and len(self.etypes) == 1
):
ret = "Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})"
return ret.format(
srcnode=self.number_of_src_nodes(),
dstnode=self.number_of_dst_nodes(),
edge=self.number_of_edges())
edge=self.number_of_edges(),
)
else:
ret = ('Block(num_src_nodes={srcnode},\n'
' num_dst_nodes={dstnode},\n'
' num_edges={edge},\n'
' metagraph={meta})')
nsrcnode_dict = {ntype : self.number_of_src_nodes(ntype)
for ntype in self.srctypes}
ndstnode_dict = {ntype : self.number_of_dst_nodes(ntype)
for ntype in self.dsttypes}
nedge_dict = {etype : self.number_of_edges(etype)
for etype in self.canonical_etypes}
ret = (
"Block(num_src_nodes={srcnode},\n"
" num_dst_nodes={dstnode},\n"
" num_edges={edge},\n"
" metagraph={meta})"
)
nsrcnode_dict = {
ntype: self.number_of_src_nodes(ntype)
for ntype in self.srctypes
}
ndstnode_dict = {
ntype: self.number_of_dst_nodes(ntype)
for ntype in self.dsttypes
}
nedge_dict = {
etype: self.number_of_edges(etype)
for etype in self.canonical_etypes
}
meta = str(self.metagraph().edges(keys=True))
return ret.format(
srcnode=nsrcnode_dict, dstnode=ndstnode_dict, edge=nedge_dict, meta=meta)
srcnode=nsrcnode_dict,
dstnode=ndstnode_dict,
edge=nedge_dict,
meta=meta,
)
def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
......@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
srctype, etype, dsttype = graph.canonical_etypes[0]
# create graph
hgidx = heterograph_index.create_unitgraph_from_coo(
2, len(unique_src), len(unique_dst), new_u, new_v, ['coo', 'csr', 'csc'])
2, len(unique_src), len(unique_dst), new_u, new_v, ["coo", "csr", "csc"]
)
# create frame
srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(unique_src)
srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(
unique_src
)
srcframe[NID] = unique_src
dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(unique_dst)
dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(
unique_dst
)
dstframe[NID] = unique_dst
eframe = graph._edge_frames[0].subframe(eid)
eframe[EID] = eid
return DGLGraph(hgidx, ([srctype], [dsttype]), [etype],
node_frames=[srcframe, dstframe],
edge_frames=[eframe]), unique_src, unique_dst, eid
return (
DGLGraph(
hgidx,
([srctype], [dsttype]),
[etype],
node_frames=[srcframe, dstframe],
edge_frames=[eframe],
),
unique_src,
unique_dst,
eid,
)
_init_api("dgl.heterograph")
......@@ -7,12 +7,11 @@ import sys
import numpy as np
import scipy
from . import backend as F
from . import utils
from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle
from .base import DGLError, dgl_warning
from .base import dgl_warning, DGLError
from .graph_index import from_coo