"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f35ac5443e5028bd677d844cb1e7d724c4d41900"
Unverified commit d1827488 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 3e5137fe
"""Base classes and functionalities for dataloaders""" """Base classes and functionalities for dataloaders"""
from collections.abc import Mapping
import inspect import inspect
from ..base import NID, EID from collections.abc import Mapping
from ..convert import heterograph
from .. import backend as F from .. import backend as F
from ..transforms import compact_graphs from ..base import EID, NID
from ..convert import heterograph
from ..frame import LazyFeature from ..frame import LazyFeature
from ..utils import recursive_apply, context_of from ..transforms import compact_graphs
from ..utils import context_of, recursive_apply


def _set_lazy_features(x, xdata, feature_names):
    if feature_names is None:
@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
    for type_, names in feature_names.items():
        x[type_].data.update({k: LazyFeature(k) for k in names})


def set_node_lazy_features(g, feature_names):
    """Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
    """
    return _set_lazy_features(g.nodes, g.ndata, feature_names)


def set_edge_lazy_features(g, feature_names):
    """Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
    """
    return _set_lazy_features(g.edges, g.edata, feature_names)


def set_src_lazy_features(g, feature_names):
    """Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
    """
    return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)


def set_dst_lazy_features(g, feature_names):
    """Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
    """
    return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)
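
Taken together, these four helpers let a sampler mark feature columns for deferred slicing by the dataloader's prefetcher. A minimal sketch of the intended call pattern inside a custom sampler's sample() method; the feature names "feat" and "weight" are hypothetical:

# Sketch only: mark columns for prefetching instead of slicing them eagerly
# here ("feat" and "weight" are made-up feature names).
def sample(self, g, indices):
    subg = g.subgraph(indices)
    set_node_lazy_features(subg, ["feat"])
    set_edge_lazy_features(subg, ["weight"])
    return subg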


class Sampler(object):
    """Base class for graph samplers.
@@ -171,6 +178,7 @@ class Sampler(object):
        def sample(self, g, indices):
            return g.subgraph(indices)
    """

    def sample(self, g, indices):
        """Abstract sample method.
@@ -183,6 +191,7 @@ class Sampler(object):
        """
        raise NotImplementedError
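
A runnable sketch of the pattern the class docstring describes, assuming dgl.dataloading exports Sampler and DataLoader as in recent releases:

import dgl
import torch


class SubgraphSampler(dgl.dataloading.Sampler):
    def sample(self, g, indices):
        # One induced subgraph per minibatch of node indices.
        return g.subgraph(indices)


g = dgl.rand_graph(100, 500)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), SubgraphSampler(), batch_size=32
)
for subg in dataloader:
    pass  # train on each sampled subgraph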


class BlockSampler(Sampler):
    """Base class for sampling mini-batches in the form of Message-passing
    Flow Graphs (MFGs).
@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
        The device of the output subgraphs or MFGs. Default is the same as the
        minibatch of seed nodes.
    """

    def __init__(
        self,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
        super().__init__()
        self.prefetch_node_feats = prefetch_node_feats or []
        self.prefetch_labels = prefetch_labels or []
@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
            set_edge_lazy_features(block, self.prefetch_edge_feats)
        return input_nodes, output_nodes, blocks

    def sample(
        self, g, seed_nodes, exclude_eids=None
    ):  # pylint: disable=arguments-differ
        """Sample a list of blocks from the given seed nodes."""
        result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
        return self.assign_lazy_features(result)
@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
        eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
        exclude_eids = {
            k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
            for k, v in eids.items()
        }
    else:
        exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
    return exclude_eids


def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
    exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
    reverse_etype_map = {
        g.to_canonical_etype(k): g.to_canonical_etype(v)
        for k, v in reverse_etype_map.items()
    }
    exclude_eids.update(
        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
    )
    return exclude_eids


def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
    if exclude_mode is None:
        return None
    elif callable(exclude_mode):
        return exclude_mode(eids)
    elif F.is_tensor(exclude_mode) or (
        isinstance(exclude_mode, Mapping)
        and all(F.is_tensor(v) for v in exclude_mode.values())
    ):
        return exclude_mode
    elif exclude_mode == "self":
        return eids
    elif exclude_mode == "reverse_id":
        return _find_exclude_eids_with_reverse_id(
            g, eids, kwargs["reverse_eid_map"]
        )
    elif exclude_mode == "reverse_types":
        return _find_exclude_eids_with_reverse_types(
            g, eids, kwargs["reverse_etype_map"]
        )
    else:
        raise ValueError("unsupported mode {}".format(exclude_mode))


def find_exclude_eids(
    g,
    seed_edges,
    exclude,
    reverse_eids=None,
    reverse_etypes=None,
    output_device=None,
):
    """Find all edge IDs to exclude according to :attr:`exclude_mode`.

    Parameters
@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
        exclude,
        seed_edges,
        reverse_eid_map=reverse_eids,
        reverse_etype_map=reverse_etypes,
    )
    if exclude_eids is not None and output_device is not None:
        exclude_eids = recursive_apply(
            exclude_eids, lambda x: F.copy_to(x, output_device)
        )
    return exclude_eids
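
To make the "reverse_id" convention concrete: reverse_eids maps every edge ID to the ID of its reverse edge, and the excluded set is the seeds followed by their reverses. A small sketch, assuming a homogeneous graph g with ten edges in which edges 0-4 and 5-9 are mutual reverses:

import torch

reverse_eids = torch.cat([torch.arange(5, 10), torch.arange(0, 5)])
seed_edges = torch.tensor([0, 3])
excluded = find_exclude_eids(g, seed_edges, "reverse_id", reverse_eids=reverse_eids)
# excluded is tensor([0, 3, 5, 8]): the seed edges, then their reverse edges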


class EdgePredictionSampler(Sampler):
    """Sampler class that wraps an existing sampler for node classification into another
    one for edge classification or link prediction.
@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
    --------
    as_edge_prediction_sampler
    """

    def __init__(
        self,
        sampler,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
        prefetch_labels=None,
    ):
        super().__init__()
        # Check if the sampler's sample method has an optional third argument.
        argspec = inspect.getfullargspec(sampler.sample)
        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
            raise TypeError(
                "This sampler does not support edge or link prediction; please add an "
                "optional third argument for edge IDs to exclude in its sample() method."
            )
        self.reverse_eids = reverse_eids
        self.reverse_etypes = reverse_etypes
        self.exclude = exclude
@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
    def _build_neg_graph(self, g, seed_edges):
        neg_srcdst = self.negative_sampler(g, seed_edges)
        if not isinstance(neg_srcdst, Mapping):
            assert len(g.canonical_etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
            neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}
        dtype = F.dtype(list(neg_srcdst.values())[0][0])
        ctx = context_of(seed_edges) if seed_edges is not None else g.device
        neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                ),
            )
            for etype in g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        )
        return neg_pair_graph
    def assign_lazy_features(self, result):
@@ -390,7 +445,7 @@ class EdgePredictionSampler(Sampler):
        # In-place updates
        return result

    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
        """Samples a list of blocks, as well as a subgraph containing the sampled
        edges from the original graph.
@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
        negative pairs as edges.
        """
        if isinstance(seed_edges, Mapping):
            seed_edges = {
                g.to_canonical_etype(k): v for k, v in seed_edges.items()
            }
        exclude = self.exclude

        pair_graph = g.edge_subgraph(
            seed_edges, relabel_nodes=False, output_device=self.output_device
        )
        eids = pair_graph.edata[EID]

        if self.negative_sampler is not None:
@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
        seed_nodes = pair_graph.ndata[NID]

        exclude_eids = find_exclude_eids(
            g,
            seed_edges,
            exclude,
            self.reverse_eids,
            self.reverse_etypes,
            self.output_device,
        )

        input_nodes, _, blocks = self.sampler.sample(
            g, seed_nodes, exclude_eids
        )

        if self.negative_sampler is None:
            return self.assign_lazy_features((input_nodes, pair_graph, blocks))
        else:
            return self.assign_lazy_features(
                (input_nodes, pair_graph, neg_graph, blocks)
            )


def as_edge_prediction_sampler(
    sampler,
    exclude=None,
    reverse_eids=None,
    reverse_etypes=None,
    negative_sampler=None,
    prefetch_labels=None,
):
    """Create an edge-wise sampler from a node-wise sampler.

    For each batch of edges, the sampler applies the provided node-wise sampler to
@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
    ...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
    """
    return EdgePredictionSampler(
        sampler,
        exclude=exclude,
        reverse_eids=reverse_eids,
        reverse_etypes=reverse_etypes,
        negative_sampler=negative_sampler,
        prefetch_labels=prefetch_labels,
    )
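
A link-prediction sketch wiring this wrapper to a neighbor sampler and a uniform negative sampler; the graph and batch sizes are arbitrary:

import dgl
import torch

g = dgl.rand_graph(100, 500)
sampler = dgl.dataloading.as_edge_prediction_sampler(
    dgl.dataloading.NeighborSampler([10, 10]),
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
)
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_edges()), sampler, batch_size=256
)
for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
    pass  # score positive pairs against the sampled negative pairs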
"""Distributed dataloaders. """Distributed dataloaders.
""" """
import inspect import inspect
from abc import ABC, abstractmethod, abstractproperty
from collections.abc import Mapping from collections.abc import Mapping
from abc import ABC, abstractproperty, abstractmethod
from .. import transforms from .. import backend as F, transforms, utils
from ..base import NID, EID from ..base import EID, NID
from .. import backend as F
from .. import utils
from ..convert import heterograph from ..convert import heterograph
from ..distributed import DistDataLoader from ..distributed import DistDataLoader
@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
        eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
        exclude_eids = {
            k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
            for k, v in eids.items()
        }
    else:
        exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
    return exclude_eids


def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
    exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
    reverse_etype_map = {
        g.to_canonical_etype(k): g.to_canonical_etype(v)
        for k, v in reverse_etype_map.items()
    }
    exclude_eids.update(
        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
    )
    return exclude_eids


def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
    """Find all edge IDs to exclude according to :attr:`exclude_mode`.
@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
    """
    if exclude_mode is None:
        return None
    elif exclude_mode == "self":
        return eids
    elif exclude_mode == "reverse_id":
        return _find_exclude_eids_with_reverse_id(
            g, eids, kwargs["reverse_eid_map"]
        )
    elif exclude_mode == "reverse_types":
        return _find_exclude_eids_with_reverse_types(
            g, eids, kwargs["reverse_etype_map"]
        )
    else:
        raise ValueError("unsupported mode {}".format(exclude_mode))


class Collator(ABC):
@@ -100,6 +109,7 @@ class Collator(ABC):
    :ref:`User Guide Section 6 <guide-minibatch>` and
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
    """

    @abstractproperty
    def dataset(self):
        """Returns the dataset object of the collator."""
@@ -122,6 +132,7 @@ class Collator(ABC):
        """
        raise NotImplementedError


class NodeCollator(Collator):
    """DGL collator to combine nodes and their computation dependencies within a minibatch for
    training node classification or regression on a single graph with neighborhood sampling.
@@ -155,14 +166,16 @@ class NodeCollator(Collator):
    :ref:`User Guide Section 6 <guide-minibatch>` and
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
    """

    def __init__(self, g, nids, graph_sampler):
        self.g = g
        if not isinstance(nids, Mapping):
            assert (
                len(g.ntypes) == 1
            ), "nids should be a dict of node type and ids for graph with multiple node types"
        self.graph_sampler = graph_sampler
        self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
        self._dataset = utils.maybe_flatten_dict(self.nids)

    @property
@@ -197,12 +210,15 @@ class NodeCollator(Collator):
        if isinstance(items[0], tuple):
            # returns a list of pairs: group them by node types into a dict
            items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g, items, "items")

        input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(
            self.g, items
        )
        return input_nodes, output_nodes, blocks
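
The collator is built to plug straight into a vanilla PyTorch DataLoader through its dataset and collate members; a sketch, assuming the PyTorch backend and that NodeCollator is exported from dgl.dataloading:

import dgl
import torch

g = dgl.rand_graph(100, 500)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
collator = dgl.dataloading.NodeCollator(g, torch.arange(g.num_nodes()), sampler)
dataloader = torch.utils.data.DataLoader(
    collator.dataset, collate_fn=collator.collate, batch_size=32, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    pass  # forward/backward over the sampled blocks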


class EdgeCollator(Collator):
    """DGL collator to combine edges and their computation dependencies within a minibatch for
    training edge classification, edge regression, or link prediction on a single graph
@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
    :ref:`User Guide Section 6 <guide-minibatch>` and
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
    """

    def __init__(
        self,
        g,
        eids,
        graph_sampler,
        g_sampling=None,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
    ):
        self.g = g
        if not isinstance(eids, Mapping):
            assert (
                len(g.etypes) == 1
            ), "eids should be a dict of etype and ids for graph with multiple etypes"
        self.graph_sampler = graph_sampler

        # One may wish to iterate over the edges in one graph while performing sampling in
@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
        self.reverse_etypes = reverse_etypes
        self.negative_sampler = negative_sampler

        self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
        self._dataset = utils.maybe_flatten_dict(self.eids)
    @property
@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
        if isinstance(items[0], tuple):
            # returns a list of pairs: group them by node types into a dict
            items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")

        pair_graph = self.g.edge_subgraph(items)
        seed_nodes = pair_graph.ndata[NID]
@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
            self.exclude,
            items,
            reverse_eid_map=self.reverse_eids,
            reverse_etype_map=self.reverse_etypes,
        )

        input_nodes, _, blocks = self.graph_sampler.sample_blocks(
            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
        )

        return input_nodes, pair_graph, blocks
@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
        if isinstance(items[0], tuple):
            # returns a list of pairs: group them by node types into a dict
            items = utils.group_as_dict(items)
        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")

        pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
        induced_edges = pair_graph.edata[EID]

        neg_srcdst = self.negative_sampler(self.g, items)
        if not isinstance(neg_srcdst, Mapping):
            assert len(self.g.etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
            neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}

        # Get dtype from a tuple of tensors
        dtype = F.dtype(list(neg_srcdst.values())[0][0])
        ctx = F.context(pair_graph)
        neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx),
                    F.copy_to(F.tensor([], dtype), ctx),
                ),
            )
            for etype in self.g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges,
            {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes},
        )

        pair_graph, neg_pair_graph = transforms.compact_graphs(
            [pair_graph, neg_pair_graph]
        )
        pair_graph.edata[EID] = induced_edges

        seed_nodes = pair_graph.ndata[NID]
@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
            self.exclude,
            items,
            reverse_eid_map=self.reverse_eids,
            reverse_etype_map=self.reverse_etypes,
        )

        input_nodes, _, blocks = self.graph_sampler.sample_blocks(
            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
        )

        return input_nodes, pair_graph, neg_pair_graph, blocks
@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
def _remove_kwargs_dist(kwargs):
    if "num_workers" in kwargs:
        del kwargs["num_workers"]
    if "pin_memory" in kwargs:
        del kwargs["pin_memory"]
        print("Distributed DataLoaders do not support pin_memory.")
    return kwargs


class DistNodeDataLoader(DistDataLoader):
    """Sampled graph data loader over nodes for distributed graph storage.
@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
    --------
    dgl.dataloading.DataLoader
    """

    def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
                dataloader_kwargs[k] = v
        if device is None:
            # for the distributed case default to the CPU
            device = "cpu"
        assert (
            device == "cpu"
        ), "Only cpu is supported in the case of a DistGraph."
        # Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features. Fallback to normal solution
        self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
        _remove_kwargs_dist(dataloader_kwargs)
        super().__init__(
            self.collator.dataset,
            collate_fn=self.collator.collate,
            **dataloader_kwargs
        )
        self.device = device
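
A distributed training sketch; the ip_config path, graph name, and mask name below are placeholders standing in for a real deployment:

import dgl
import torch

dgl.distributed.initialize("ip_config.txt")      # placeholder config file
torch.distributed.init_process_group(backend="gloo")
g = dgl.distributed.DistGraph("my_graph")        # hypothetical graph name
train_nids = dgl.distributed.node_split(g.ndata["train_mask"])
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
dataloader = dgl.dataloading.DistNodeDataLoader(
    g, train_nids, sampler, batch_size=1024, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    pass  # one training step per minibatch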


class DistEdgeDataLoader(DistDataLoader):
    """Sampled graph data loader over edges for distributed graph storage.
@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
    --------
    dgl.dataloading.DataLoader
    """

    def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
        if device is None:
            # for the distributed case default to the CPU
            device = "cpu"
        assert (
            device == "cpu"
        ), "Only cpu is supported in the case of a DistGraph."
        # Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features. Fallback to normal solution
        self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
        _remove_kwargs_dist(dataloader_kwargs)
        super().__init__(
            self.collator.dataset,
            collate_fn=self.collator.collate,
            **dataloader_kwargs
        )
        self.device = device
@@ -17,11 +17,11 @@
#
"""Data loading components for labor sampling"""
from .. import backend as F
from ..base import EID, NID
from ..random import choice
from ..transforms import to_block
from .base import BlockSampler


class LaborSampler(BlockSampler):
@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
            )
            block.edata[EID] = eid
            if len(g.canonical_etypes) > 1:
                for etype, importance in zip(g.canonical_etypes, importances):
                    if importance.shape[0] == block.num_edges(etype):
                        block.edata["edge_weights"][etype] = importance
            elif importances[0].shape[0] == block.num_edges():
...
"""Data loading components for neighbor sampling""" """Data loading components for neighbor sampling"""
from ..base import NID, EID from ..base import EID, NID
from ..transforms import to_block from ..transforms import to_block
from .base import BlockSampler from .base import BlockSampler
class NeighborSampler(BlockSampler): class NeighborSampler(BlockSampler):
"""Sampler that builds computational dependency of node representations via """Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN. neighbor sampling for multilayer GNN.
...@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler): ...@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and :ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`. :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
""" """

    def __init__(
        self,
        fanouts,
        edge_dir="in",
        prob=None,
        mask=None,
        replace=False,
        prefetch_node_feats=None,
        prefetch_labels=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
        super().__init__(
            prefetch_node_feats=prefetch_node_feats,
            prefetch_labels=prefetch_labels,
            prefetch_edge_feats=prefetch_edge_feats,
            output_device=output_device,
        )
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        if mask is not None and prob is not None:
            raise ValueError(
                "Mask and probability arguments are mutually exclusive. "
                "Consider multiplying the probability with the mask "
                "to achieve the same goal."
            )
        self.prob = prob or mask
        self.replace = replace
@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
        blocks = []
        for fanout in reversed(self.fanouts):
            frontier = g.sample_neighbors(
                seed_nodes,
                fanout,
                edge_dir=self.edge_dir,
                prob=self.prob,
                replace=self.replace,
                output_device=self.output_device,
                exclude_edges=exclude_eids,
            )
            eid = frontier.edata[EID]
            block = to_block(frontier, seed_nodes)
            block.edata[EID] = eid
@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
        return seed_nodes, output_nodes, blocks
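
The canonical node-classification loop this sampler is built for, as a sketch; the "feat" name is hypothetical and is prefetched through the lazy-feature machinery above:

import dgl
import torch

g = dgl.rand_graph(100, 500)
g.ndata["feat"] = torch.randn(100, 16)
sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prefetch_node_feats=["feat"])
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler, batch_size=32, shuffle=True
)
for input_nodes, output_nodes, blocks in dataloader:
    x = blocks[0].srcdata["feat"]  # input features, already sliced/prefetched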

MultiLayerNeighborSampler = NeighborSampler


class MultiLayerFullNeighborSampler(NeighborSampler):
    """Sampler that builds computational dependency of node representations by taking messages
    from all neighbors for multilayer GNN.
@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
    :ref:`User Guide Section 6 <guide-minibatch>` and
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
    """

    def __init__(self, num_layers, **kwargs):
        super().__init__([-1] * num_layers, **kwargs)
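
Given the [-1] * num_layers construction above, the two spellings below are equivalent; a fanout of -1 selects all neighbors:

sampler_a = MultiLayerFullNeighborSampler(2)
sampler_b = NeighborSampler([-1, -1])  # same sampling behavior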
"""ShaDow-GNN subgraph samplers.""" """ShaDow-GNN subgraph samplers."""
from ..sampling.utils import EidExcluder
from .. import transforms from .. import transforms
from ..base import NID from ..base import NID
from .base import set_node_lazy_features, set_edge_lazy_features, Sampler from ..sampling.utils import EidExcluder
from .base import Sampler, set_edge_lazy_features, set_node_lazy_features


class ShaDowKHopSampler(Sampler):
    """K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
    >>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
    >>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
    """

    def __init__(
        self,
        fanouts,
        replace=False,
        prob=None,
        prefetch_node_feats=None,
        prefetch_edge_feats=None,
        output_device=None,
    ):
        super().__init__()
        self.fanouts = fanouts
        self.replace = replace
@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
        self.prefetch_edge_feats = prefetch_edge_feats
        self.output_device = output_device
    def sample(
        self, g, seed_nodes, exclude_eids=None
    ):  # pylint: disable=arguments-differ
        """Sampling function.

        Parameters
@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
        output_nodes = seed_nodes
        for fanout in reversed(self.fanouts):
            frontier = g.sample_neighbors(
                seed_nodes,
                fanout,
                output_device=self.output_device,
                replace=self.replace,
                prob=self.prob,
                exclude_edges=exclude_eids,
            )
            block = transforms.to_block(frontier, seed_nodes)
            seed_nodes = block.srcdata[NID]

        subg = g.subgraph(
            seed_nodes, relabel_nodes=True, output_device=self.output_device
        )
        if exclude_eids is not None:
            subg = EidExcluder(exclude_eids)(subg)
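
Unlike BlockSampler subclasses, this sampler yields one subgraph per minibatch rather than a list of MFGs; a usage sketch:

import dgl
import torch

g = dgl.rand_graph(100, 500)
sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
dataloader = dgl.dataloading.DataLoader(
    g, torch.arange(g.num_nodes()), sampler, batch_size=32
)
for input_nodes, output_nodes, subgraph in dataloader:
    pass  # run the full GNN on `subgraph`; read predictions at `output_nodes`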
...
@@ -7,8 +7,7 @@ from itertools import product
from .base import BuiltinFunction, TargetCode

__all__ = ["copy_u", "copy_e", "BinaryMessageFunction", "CopyMessageFunction"]


class MessageFunction(BuiltinFunction):
@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
    --------
    u_mul_e
    """

    def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
        self.binary_op = binary_op
        self.lhs = lhs
@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
    --------
    copy_u
    """

    def __init__(self, target, in_field, out_field):
        self.target = target
        self.in_field = in_field
@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
    --------
    >>> import dgl
    >>> message_func = dgl.function.{}('h', 'h', 'm')
    """.format(
        binary_op,
        TargetCode.CODE2STR[_TARGET_MAP[lhs]],
        TargetCode.CODE2STR[_TARGET_MAP[rhs]],
        TargetCode.CODE2STR[_TARGET_MAP[lhs]],
        TargetCode.CODE2STR[_TARGET_MAP[rhs]],
        name,
    )

    def func(lhs_field, rhs_field, out):
        return BinaryMessageFunction(
            binary_op,
            _TARGET_MAP[lhs],
            _TARGET_MAP[rhs],
            lhs_field,
            rhs_field,
            out,
        )

    func.__name__ = name
    func.__doc__ = docstring
    return func
@@ -177,4 +186,5 @@ def _register_builtin_message_func():
        setattr(sys.modules[__name__], func.__name__, func)
        __all__.append(func.__name__)


_register_builtin_message_func()
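
The generated builtins follow the {lhs}_{op}_{rhs} naming scheme; a sketch of one of them, u_mul_e, inside message passing (assuming the PyTorch backend):

import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(100, 500)
g.ndata["h"] = torch.randn(100, 8)
g.edata["w"] = torch.rand(500, 1)
# Multiply each source node's 'h' by the edge's 'w' into message 'm',
# then sum the messages at each destination node.
g.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h_new"))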
"""Module for various graph generator functions.""" """Module for various graph generator functions."""
from . import backend as F from . import backend as F, convert, random
from . import convert, random
__all__ = ["rand_graph", "rand_bipartite"] __all__ = ["rand_graph", "rand_bipartite"]
...
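
Both exported generators take explicit sizes; a quick sketch, with the type names and sizes chosen arbitrarily:

import dgl

g = dgl.rand_graph(100, 500)  # 100 nodes, 500 uniformly random edges
bg = dgl.rand_bipartite("user", "rates", "movie", 50, 60, 300)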
"""Python interfaces to DGL farthest point sampler.""" """Python interfaces to DGL farthest point sampler."""
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F, ndarray as nd
from .. import ndarray as nd
from .._ffi.base import DGLError from .._ffi.base import DGLError
from .._ffi.function import _init_api from .._ffi.function import _init_api
...
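
These FFI bindings back dgl.geometry.farthest_point_sampler; a usage sketch, assuming the PyTorch backend:

import torch
from dgl.geometry import farthest_point_sampler

pos = torch.rand(2, 1000, 3)                   # a batch of two 3D point clouds
idx = farthest_point_sampler(pos, npoints=64)  # (2, 64) indices of sampled points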
@@ -5,11 +5,10 @@ import networkx as nx
import numpy as np
import scipy

from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from .base import dgl_warning, DGLError


class BoolFlag(object):
...
@@ -7,12 +7,11 @@ import sys
import numpy as np
import scipy

from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle
from .base import dgl_warning, DGLError
from .graph_index import from_coo
...