Unverified Commit 3e5137fe authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Autoformat python dgl (Part 2). (#5332)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 6e58f5f1
...@@ -10,7 +10,7 @@ __all__ = ["copy_u", "copy_v"] ...@@ -10,7 +10,7 @@ __all__ = ["copy_u", "copy_v"]
####################################################### #######################################################
def copy_u(g, x_node, etype = None): def copy_u(g, x_node, etype=None):
"""Compute new edge data by fetching from source node data. """Compute new edge data by fetching from source node data.
Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph
...@@ -64,7 +64,7 @@ def copy_u(g, x_node, etype = None): ...@@ -64,7 +64,7 @@ def copy_u(g, x_node, etype = None):
return ops.gsddmm(etype_subg, "copy_lhs", x_node, None) return ops.gsddmm(etype_subg, "copy_lhs", x_node, None)
def copy_v(g, x_node, etype = None): def copy_v(g, x_node, etype=None):
"""Compute new edge data by fetching from destination node data. """Compute new edge data by fetching from destination node data.
Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph
...@@ -212,9 +212,11 @@ Examples ...@@ -212,9 +212,11 @@ Examples
(500, 5) (500, 5)
""" """
def func(g, x_node, y_node, etype = None): def func(g, x_node, y_node, etype=None):
etype_subg = g if etype is None else g[etype] etype_subg = g if etype is None else g[etype]
return ops.gsddmm(etype_subg, op, x_node, y_node, lhs_target="u", rhs_target="v") return ops.gsddmm(
etype_subg, op, x_node, y_node, lhs_target="u", rhs_target="v"
)
func.__name__ = name func.__name__ = name
func.__doc__ = docstring func.__doc__ = docstring
......
...@@ -15,14 +15,14 @@ import numpy as _np ...@@ -15,14 +15,14 @@ import numpy as _np
from . import backend as F from . import backend as F
from ._ffi.function import _init_api from ._ffi.function import _init_api
from ._ffi.ndarray import ( from ._ffi.ndarray import (
DGLContext,
DGLDataType,
NDArrayBase,
_set_class_ndarray, _set_class_ndarray,
context, context,
DGLContext,
DGLDataType,
empty, empty,
empty_shared_mem, empty_shared_mem,
from_dlpack, from_dlpack,
NDArrayBase,
numpyasarray, numpyasarray,
) )
from ._ffi.object import ObjectBase, register_object from ._ffi.object import ObjectBase, register_object
......
"""dgl edge_softmax operator module.""" """dgl edge_softmax operator module."""
from ..backend import astype from ..backend import (
from ..backend import edge_softmax as edge_softmax_internal astype,
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal edge_softmax as edge_softmax_internal,
edge_softmax_hetero as edge_softmax_hetero_internal,
)
from ..base import ALL, is_all from ..base import ALL, is_all
__all__ = ["edge_softmax"] __all__ = ["edge_softmax"]
......
...@@ -3,8 +3,10 @@ import sys ...@@ -3,8 +3,10 @@ import sys
from itertools import product from itertools import product
from .. import backend as F from .. import backend as F
from ..backend import gsddmm as gsddmm_internal from ..backend import (
from ..backend import gsddmm_hetero as gsddmm_internal_hetero gsddmm as gsddmm_internal,
gsddmm_hetero as gsddmm_internal_hetero,
)
__all__ = ["gsddmm", "copy_u", "copy_v", "copy_e"] __all__ = ["gsddmm", "copy_u", "copy_v", "copy_e"]
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
import sys import sys
from .. import backend as F from .. import backend as F
from ..backend import gspmm as gspmm_internal from ..backend import (
from ..backend import gspmm_hetero as gspmm_internal_hetero gspmm as gspmm_internal,
gspmm_hetero as gspmm_internal_hetero,
)
__all__ = ["gspmm"] __all__ = ["gspmm"]
......
...@@ -5,8 +5,7 @@ import time ...@@ -5,8 +5,7 @@ import time
import numpy as np import numpy as np
from . import backend as F from . import backend as F, utils
from . import utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import EID, ETYPE, NID, NTYPE from .base import EID, ETYPE, NID, NTYPE
from .heterograph import DGLGraph from .heterograph import DGLGraph
......
"""Module for message propagation.""" """Module for message propagation."""
from __future__ import absolute_import from __future__ import absolute_import
from . import backend as F from . import backend as F, traversal as trv
from . import traversal as trv
from .heterograph import DGLGraph from .heterograph import DGLGraph
__all__ = [ __all__ = [
......
"""Python interfaces to DGL random number generators.""" """Python interfaces to DGL random number generators."""
import numpy as np import numpy as np
from . import backend as F from . import backend as F, ndarray as nd
from . import ndarray as nd
from ._ffi.function import _init_api from ._ffi.function import _init_api
__all__ = ["seed"] __all__ = ["seed"]
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from . import backend as F from . import backend as F
from .base import DGLError, dgl_warning from .base import dgl_warning, DGLError
from .ops import segment from .ops import segment
__all__ = [ __all__ = [
......
...@@ -19,14 +19,12 @@ ...@@ -19,14 +19,12 @@
"""Labor sampling APIs""" """Labor sampling APIs"""
from .. import backend as F, ndarray as nd, utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F
from ..base import DGLError from ..base import DGLError
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .. import ndarray as nd
from .. import utils
from .utils import EidExcluder
from ..random import choice from ..random import choice
from .utils import EidExcluder
__all__ = ["sample_labors"] __all__ = ["sample_labors"]
...@@ -230,7 +228,10 @@ def sample_labors( ...@@ -230,7 +228,10 @@ def sample_labors(
if output_device is None: if output_device is None:
return (frontier, importances) return (frontier, importances)
else: else:
return (frontier.to(output_device), list(map(lambda x: x.to(output_device), importances))) return (
frontier.to(output_device),
list(map(lambda x: x.to(output_device), importances)),
)
def _sample_labors( def _sample_labors(
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
from numpy.polynomial import polynomial from numpy.polynomial import polynomial
from .. import backend as F from .. import backend as F, utils
from .. import utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
......
"""Neighbor sampling APIs""" """Neighbor sampling APIs"""
from .. import backend as F, ndarray as nd, utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F
from ..base import DGLError, EID from ..base import DGLError, EID
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .. import ndarray as nd
from .. import utils
from .utils import EidExcluder from .utils import EidExcluder
__all__ = [ __all__ = [
'sample_etype_neighbors', "sample_etype_neighbors",
'sample_neighbors', "sample_neighbors",
'sample_neighbors_biased', "sample_neighbors_biased",
'select_topk'] "select_topk",
]
def _prepare_edge_arrays(g, arg): def _prepare_edge_arrays(g, arg):
"""Converts the argument into a list of NDArrays. """Converts the argument into a list of NDArrays.
...@@ -42,8 +42,10 @@ def _prepare_edge_arrays(g, arg): ...@@ -42,8 +42,10 @@ def _prepare_edge_arrays(g, arg):
result = [ result = [
F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx)) F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx))
if x is None else x if x is None
for x in result] else x
for x in result
]
return result return result
elif arg is None: elif arg is None:
return [nd.array([], ctx=nd.cpu())] * len(g.etypes) return [nd.array([], ctx=nd.cpu())] * len(g.etypes)
...@@ -56,10 +58,21 @@ def _prepare_edge_arrays(g, arg): ...@@ -56,10 +58,21 @@ def _prepare_edge_arrays(g, arg):
arrays.append(nd.array([], ctx=nd.cpu())) arrays.append(nd.array([], ctx=nd.cpu()))
return arrays return arrays
def sample_etype_neighbors( def sample_etype_neighbors(
g, nodes, etype_offset, fanout, edge_dir='in', prob=None, g,
replace=False, copy_ndata=True, copy_edata=True, etype_sorted=False, nodes,
_dist_training=False, output_device=None): etype_offset,
fanout,
edge_dir="in",
prob=None,
replace=False,
copy_ndata=True,
copy_edata=True,
etype_sorted=False,
_dist_training=False,
output_device=None,
):
"""Sample neighboring edges of the given nodes and return the induced subgraph. """Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -144,7 +157,7 @@ def sample_etype_neighbors( ...@@ -144,7 +157,7 @@ def sample_etype_neighbors(
assert len(nodes) == 1, "The input graph should not have node types" assert len(nodes) == 1, "The input graph should not have node types"
nodes = list(nodes.values())[0] nodes = list(nodes.values())[0]
nodes = utils.prepare_tensor(g, nodes, 'nodes') nodes = utils.prepare_tensor(g, nodes, "nodes")
device = utils.context_of(nodes) device = utils.context_of(nodes)
nodes = F.to_dgl_nd(nodes) nodes = F.to_dgl_nd(nodes)
# treat etypes as int32, it is much cheaper than int64 # treat etypes as int32, it is much cheaper than int64
...@@ -154,8 +167,15 @@ def sample_etype_neighbors( ...@@ -154,8 +167,15 @@ def sample_etype_neighbors(
prob_array = _prepare_edge_arrays(g, prob) prob_array = _prepare_edge_arrays(g, prob)
subgidx = _CAPI_DGLSampleNeighborsEType( subgidx = _CAPI_DGLSampleNeighborsEType(
g._graph, nodes, etype_offset, fanout, edge_dir, prob_array, g._graph,
replace, etype_sorted) nodes,
etype_offset,
fanout,
edge_dir,
prob_array,
replace,
etype_sorted,
)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -178,12 +198,23 @@ def sample_etype_neighbors( ...@@ -178,12 +198,23 @@ def sample_etype_neighbors(
return ret if output_device is None else ret.to(output_device) return ret if output_device is None else ret.to(output_device)
DGLGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors) DGLGraph.sample_etype_neighbors = utils.alias_func(sample_etype_neighbors)
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
replace=False, copy_ndata=True, copy_edata=True, def sample_neighbors(
_dist_training=False, exclude_edges=None, g,
output_device=None): nodes,
fanout,
edge_dir="in",
prob=None,
replace=False,
copy_ndata=True,
copy_edata=True,
_dist_training=False,
exclude_edges=None,
output_device=None,
):
"""Sample neighboring edges of the given nodes and return the induced subgraph. """Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
...@@ -319,33 +350,60 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, ...@@ -319,33 +350,60 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
tensor([False, False, False]) tensor([False, False, False])
""" """
if F.device_type(g.device) == 'cpu' and not g.is_pinned(): if F.device_type(g.device) == "cpu" and not g.is_pinned():
frontier = _sample_neighbors( frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, g,
replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata, nodes,
exclude_edges=exclude_edges) fanout,
edge_dir=edge_dir,
prob=prob,
replace=replace,
copy_ndata=copy_ndata,
copy_edata=copy_edata,
exclude_edges=exclude_edges,
)
else: else:
frontier = _sample_neighbors( frontier = _sample_neighbors(
g, nodes, fanout, edge_dir=edge_dir, prob=prob, g,
replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata) nodes,
fanout,
edge_dir=edge_dir,
prob=prob,
replace=replace,
copy_ndata=copy_ndata,
copy_edata=copy_edata,
)
if exclude_edges is not None: if exclude_edges is not None:
eid_excluder = EidExcluder(exclude_edges) eid_excluder = EidExcluder(exclude_edges)
frontier = eid_excluder(frontier) frontier = eid_excluder(frontier)
return frontier if output_device is None else frontier.to(output_device) return frontier if output_device is None else frontier.to(output_device)
def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
replace=False, copy_ndata=True, copy_edata=True, def _sample_neighbors(
_dist_training=False, exclude_edges=None): g,
nodes,
fanout,
edge_dir="in",
prob=None,
replace=False,
copy_ndata=True,
copy_edata=True,
_dist_training=False,
exclude_edges=None,
):
if not isinstance(nodes, dict): if not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError(
nodes = {g.ntypes[0] : nodes} "Must specify node type when the graph is not homogeneous."
)
nodes = {g.ntypes[0]: nodes}
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes') nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
if len(nodes) == 0: if len(nodes) == 0:
raise ValueError( raise ValueError(
"Got an empty dictionary in the nodes argument. " "Got an empty dictionary in the nodes argument. "
"Please pass in a dictionary with empty tensors as values instead.") "Please pass in a dictionary with empty tensors as values instead."
)
device = utils.context_of(nodes) device = utils.context_of(nodes)
ctx = utils.to_dgl_context(device) ctx = utils.to_dgl_context(device)
nodes_all_types = [] nodes_all_types = []
...@@ -362,8 +420,10 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, ...@@ -362,8 +420,10 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
fanout_array = [int(fanout)] * len(g.etypes) fanout_array = [int(fanout)] * len(g.etypes)
else: else:
if len(fanout) != len(g.etypes): if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type ' raise DGLError(
'if a dict is provided.') "Fan-out must be specified for each edge type "
"if a dict is provided."
)
fanout_array = [None] * len(g.etypes) fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items(): for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value fanout_array[g.get_etype_id(etype)] = value
...@@ -375,9 +435,11 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, ...@@ -375,9 +435,11 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
if exclude_edges is not None: if exclude_edges is not None:
if not isinstance(exclude_edges, dict): if not isinstance(exclude_edges, dict):
if len(g.etypes) > 1: if len(g.etypes) > 1:
raise DGLError("Must specify etype when the graph is not homogeneous.") raise DGLError(
exclude_edges = {g.canonical_etypes[0] : exclude_edges} "Must specify etype when the graph is not homogeneous."
exclude_edges = utils.prepare_tensor_dict(g, exclude_edges, 'edges') )
exclude_edges = {g.canonical_etypes[0]: exclude_edges}
exclude_edges = utils.prepare_tensor_dict(g, exclude_edges, "edges")
for etype in g.canonical_etypes: for etype in g.canonical_etypes:
if etype in exclude_edges: if etype in exclude_edges:
excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype])) excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))
...@@ -385,8 +447,14 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, ...@@ -385,8 +447,14 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
excluded_edges_all_t.append(nd.array([], ctx=ctx)) excluded_edges_all_t.append(nd.array([], ctx=ctx))
subgidx = _CAPI_DGLSampleNeighbors( subgidx = _CAPI_DGLSampleNeighbors(
g._graph, nodes_all_types, fanout_array, edge_dir, prob_arrays, g._graph,
excluded_edges_all_t, replace) nodes_all_types,
fanout_array,
edge_dir,
prob_arrays,
excluded_edges_all_t,
replace,
)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -409,11 +477,22 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, ...@@ -409,11 +477,22 @@ def _sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None,
return ret return ret
DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors) DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors)
def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
tag_offset_name='_TAG_OFFSET', replace=False, def sample_neighbors_biased(
copy_ndata=True, copy_edata=True, output_device=None): g,
nodes,
fanout,
bias,
edge_dir="in",
tag_offset_name="_TAG_OFFSET",
replace=False,
copy_ndata=True,
copy_edata=True,
output_device=None,
):
r"""Sample neighboring edges of the given nodes and return the induced subgraph, where each r"""Sample neighboring edges of the given nodes and return the induced subgraph, where each
neighbor's probability to be picked is determined by its tag. neighbor's probability to be picked is determined by its tag.
...@@ -558,15 +637,22 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', ...@@ -558,15 +637,22 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
nodes_array = F.to_dgl_nd(nodes) nodes_array = F.to_dgl_nd(nodes)
bias_array = F.to_dgl_nd(bias) bias_array = F.to_dgl_nd(bias)
if edge_dir == 'in': if edge_dir == "in":
tag_offset_array = F.to_dgl_nd(g.dstdata[tag_offset_name]) tag_offset_array = F.to_dgl_nd(g.dstdata[tag_offset_name])
elif edge_dir == 'out': elif edge_dir == "out":
tag_offset_array = F.to_dgl_nd(g.srcdata[tag_offset_name]) tag_offset_array = F.to_dgl_nd(g.srcdata[tag_offset_name])
else: else:
raise DGLError("edge_dir can only be 'in' or 'out'") raise DGLError("edge_dir can only be 'in' or 'out'")
subgidx = _CAPI_DGLSampleNeighborsBiased(g._graph, nodes_array, fanout, bias_array, subgidx = _CAPI_DGLSampleNeighborsBiased(
tag_offset_array, edge_dir, replace) g._graph,
nodes_array,
fanout,
bias_array,
tag_offset_array,
edge_dir,
replace,
)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -581,10 +667,21 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in', ...@@ -581,10 +667,21 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
ret.edata[EID] = induced_edges[0] ret.edata[EID] = induced_edges[0]
return ret if output_device is None else ret.to(output_device) return ret if output_device is None else ret.to(output_device)
DGLGraph.sample_neighbors_biased = utils.alias_func(sample_neighbors_biased) DGLGraph.sample_neighbors_biased = utils.alias_func(sample_neighbors_biased)
def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
copy_ndata=True, copy_edata=True, output_device=None): def select_topk(
g,
k,
weight,
nodes=None,
edge_dir="in",
ascending=False,
copy_ndata=True,
copy_edata=True,
output_device=None,
):
"""Select the neighboring edges with k-largest (or k-smallest) weights of the given """Select the neighboring edges with k-largest (or k-smallest) weights of the given
nodes and return the induced subgraph. nodes and return the induced subgraph.
...@@ -669,12 +766,14 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, ...@@ -669,12 +766,14 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
} }
elif not isinstance(nodes, dict): elif not isinstance(nodes, dict):
if len(g.ntypes) > 1: if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.") raise DGLError(
nodes = {g.ntypes[0] : nodes} "Must specify node type when the graph is not homogeneous."
)
nodes = {g.ntypes[0]: nodes}
assert g.device == F.cpu(), "Graph must be on CPU." assert g.device == F.cpu(), "Graph must be on CPU."
# Parse nodes into a list of NDArrays. # Parse nodes into a list of NDArrays.
nodes = utils.prepare_tensor_dict(g, nodes, 'nodes') nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
device = utils.context_of(nodes) device = utils.context_of(nodes)
nodes_all_types = [] nodes_all_types = []
for ntype in g.ntypes: for ntype in g.ntypes:
...@@ -687,8 +786,10 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, ...@@ -687,8 +786,10 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
k_array = [int(k)] * len(g.etypes) k_array = [int(k)] * len(g.etypes)
else: else:
if len(k) != len(g.etypes): if len(k) != len(g.etypes):
raise DGLError('K value must be specified for each edge type ' raise DGLError(
'if a dict is provided.') "K value must be specified for each edge type "
"if a dict is provided."
)
k_array = [None] * len(g.etypes) k_array = [None] * len(g.etypes)
for etype, value in k.items(): for etype, value in k.items():
k_array[g.get_etype_id(etype)] = value k_array[g.get_etype_id(etype)] = value
...@@ -699,11 +800,20 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, ...@@ -699,11 +800,20 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
if weight in g.edges[etype].data: if weight in g.edges[etype].data:
weight_arrays.append(F.to_dgl_nd(g.edges[etype].data[weight])) weight_arrays.append(F.to_dgl_nd(g.edges[etype].data[weight]))
else: else:
raise DGLError('Edge weights "{}" do not exist for relation graph "{}".'.format( raise DGLError(
weight, etype)) 'Edge weights "{}" do not exist for relation graph "{}".'.format(
weight, etype
)
)
subgidx = _CAPI_DGLSampleNeighborsTopk( subgidx = _CAPI_DGLSampleNeighborsTopk(
g._graph, nodes_all_types, k_array, edge_dir, weight_arrays, bool(ascending)) g._graph,
nodes_all_types,
k_array,
edge_dir,
weight_arrays,
bool(ascending),
)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
...@@ -717,6 +827,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False, ...@@ -717,6 +827,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False,
utils.set_new_frames(ret, edge_frames=edge_frames) utils.set_new_frames(ret, edge_frames=edge_frames)
return ret if output_device is None else ret.to(output_device) return ret if output_device is None else ret.to(output_device)
DGLGraph.select_topk = utils.alias_func(select_topk) DGLGraph.select_topk = utils.alias_func(select_topk)
_init_api('dgl.sampling.neighbor', __name__) _init_api("dgl.sampling.neighbor", __name__)
"""Node2vec random walk""" """Node2vec random walk"""
from .. import backend as F from .. import backend as F, ndarray as nd, utils
from .. import ndarray as nd
from .. import utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
# pylint: disable=invalid-name # pylint: disable=invalid-name
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F, convert, utils
from .. import convert, utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .randomwalks import random_walk from .randomwalks import random_walk
......
"""Random walk routines """Random walk routines
""" """
from .. import backend as F from .. import backend as F, ndarray as nd, utils
from .. import ndarray as nd
from .. import utils
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import DGLError from ..base import DGLError
......
"""Sampling utilities""" """Sampling utilities"""
from collections.abc import Mapping from collections.abc import Mapping
import numpy as np import numpy as np
from ..utils import recursive_apply, recursive_apply_pair from .. import backend as F, transforms, utils
from ..base import EID from ..base import EID
from .. import backend as F
from .. import transforms, utils from ..utils import recursive_apply, recursive_apply_pair
def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids): def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
"""Find the edges whose IDs in parent graph appeared in exclude_eids. """Find the edges whose IDs in parent graph appeared in exclude_eids.
...@@ -20,11 +22,13 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids): ...@@ -20,11 +22,13 @@ def _locate_eids_to_exclude(frontier_parent_eids, exclude_eids):
result[k] = np.isin(v, exclude_eids[k]).nonzero()[0] result[k] = np.isin(v, exclude_eids[k]).nonzero()[0]
return recursive_apply(result, F.zerocopy_from_numpy) return recursive_apply(result, F.zerocopy_from_numpy)
class EidExcluder(object): class EidExcluder(object):
"""Class that finds the edges whose IDs in parent graph appeared in exclude_eids. """Class that finds the edges whose IDs in parent graph appeared in exclude_eids.
The edge IDs can be both CPU and GPU tensors. The edge IDs can be both CPU and GPU tensors.
""" """
def __init__(self, exclude_eids): def __init__(self, exclude_eids):
device = None device = None
if isinstance(exclude_eids, Mapping): if isinstance(exclude_eids, Mapping):
...@@ -42,13 +46,14 @@ class EidExcluder(object): ...@@ -42,13 +46,14 @@ class EidExcluder(object):
# should just use that irregardless of the device. # should just use that irregardless of the device.
self._exclude_eids = ( self._exclude_eids = (
recursive_apply(exclude_eids, F.zerocopy_to_numpy) recursive_apply(exclude_eids, F.zerocopy_to_numpy)
if exclude_eids is not None else None) if exclude_eids is not None
else None
)
else: else:
self._filter = recursive_apply(exclude_eids, utils.Filter) self._filter = recursive_apply(exclude_eids, utils.Filter)
def _find_indices(self, parent_eids): def _find_indices(self, parent_eids):
""" Find the set of edge indices to remove. """Find the set of edge indices to remove."""
"""
if self._exclude_eids is not None: if self._exclude_eids is not None:
parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy) parent_eids_np = recursive_apply(parent_eids, F.zerocopy_to_numpy)
return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids) return _locate_eids_to_exclude(parent_eids_np, self._exclude_eids)
...@@ -68,10 +73,16 @@ class EidExcluder(object): ...@@ -68,10 +73,16 @@ class EidExcluder(object):
# So we need to test if located_eids is empty, and do the remapping ourselves. # So we need to test if located_eids is empty, and do the remapping ourselves.
if len(located_eids) > 0: if len(located_eids) > 0:
frontier = transforms.remove_edges( frontier = transforms.remove_edges(
frontier, located_eids, store_ids=True) frontier, located_eids, store_ids=True
if weights is not None and weights[0].shape[0] == frontier.num_edges(): )
if (
weights is not None
and weights[0].shape[0] == frontier.num_edges()
):
weights[0] = F.gather_row(weights[0], frontier.edata[EID]) weights[0] = F.gather_row(weights[0], frontier.edata[EID])
frontier.edata[EID] = F.gather_row(parent_eids, frontier.edata[EID]) frontier.edata[EID] = F.gather_row(
parent_eids, frontier.edata[EID]
)
else: else:
# (BarclayII) remove_edges only accepts removing one type of edges, # (BarclayII) remove_edges only accepts removing one type of edges,
# so I need to keep track of the edge IDs left one by one. # so I need to keep track of the edge IDs left one by one.
...@@ -79,9 +90,16 @@ class EidExcluder(object): ...@@ -79,9 +90,16 @@ class EidExcluder(object):
for i, (k, v) in enumerate(located_eids.items()): for i, (k, v) in enumerate(located_eids.items()):
if len(v) > 0: if len(v) > 0:
frontier = transforms.remove_edges( frontier = transforms.remove_edges(
frontier, v, etype=k, store_ids=True) frontier, v, etype=k, store_ids=True
new_eids[k] = F.gather_row(parent_eids[k], frontier.edges[k].data[EID]) )
if weights is not None and weights[i].shape[0] == frontier.num_edges(k): new_eids[k] = F.gather_row(
weights[i] = F.gather_row(weights[i], frontier.edges[k].data[EID]) parent_eids[k], frontier.edges[k].data[EID]
)
if weights is not None and weights[i].shape[
0
] == frontier.num_edges(k):
weights[i] = F.gather_row(
weights[i], frontier.edges[k].data[EID]
)
frontier.edata[EID] = new_eids frontier.edata[EID] = new_eids
return frontier if weights is None else (frontier, weights) return frontier if weights is None else (frontier, weights)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F
from .base import FeatureStorage, ThreadedFuture, register_storage_wrapper from .base import FeatureStorage, register_storage_wrapper, ThreadedFuture
@register_storage_wrapper(np.memmap) @register_storage_wrapper(np.memmap)
......
...@@ -5,8 +5,7 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.samplin ...@@ -5,8 +5,7 @@ For stochastic subgraph extraction, please see functions under :mod:`dgl.samplin
""" """
from collections.abc import Mapping from collections.abc import Mapping
from . import backend as F from . import backend as F, graph_index, heterograph_index, utils
from . import graph_index, heterograph_index, utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from .heterograph import DGLGraph from .heterograph import DGLGraph
...@@ -177,12 +176,7 @@ DGLGraph.subgraph = utils.alias_func(node_subgraph) ...@@ -177,12 +176,7 @@ DGLGraph.subgraph = utils.alias_func(node_subgraph)
def edge_subgraph( def edge_subgraph(
graph, graph, edges, *, relabel_nodes=True, store_ids=True, output_device=None
edges,
*,
relabel_nodes=True,
store_ids=True,
output_device=None
): ):
"""Return a subgraph induced on the given edges. """Return a subgraph induced on the given edges.
......
"""Module for graph traversal methods.""" """Module for graph traversal methods."""
from __future__ import absolute_import from __future__ import absolute_import
from . import backend as F from . import backend as F, utils
from . import utils
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .heterograph import DGLGraph from .heterograph import DGLGraph
......
...@@ -9,10 +9,9 @@ from functools import wraps ...@@ -9,10 +9,9 @@ from functools import wraps
import numpy as np import numpy as np
from .. import backend as F from .. import backend as F, ndarray as nd
from .. import ndarray as nd
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..base import EID, NID, DGLError, dgl_warning from ..base import dgl_warning, DGLError, EID, NID
def is_listlike(data): def is_listlike(data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment