Unverified commit a566b60b, authored by Hongzhi (Steve), Chen and committed by GitHub
parent d1827488
@@ -14,26 +14,26 @@ try:
     if _FFI_MODE == "ctypes":
         raise ImportError()
     if sys.version_info >= (3, 0):
-        from ._cy3.core import FunctionBase as _FunctionBase
         from ._cy3.core import (
             _set_class_function,
             _set_class_module,
             convert_to_dgl_func,
+            FunctionBase as _FunctionBase,
         )
     else:
-        from ._cy2.core import FunctionBase as _FunctionBase
         from ._cy2.core import (
             _set_class_function,
             _set_class_module,
             convert_to_dgl_func,
+            FunctionBase as _FunctionBase,
         )
 except IMPORT_EXCEPT:
     # pylint: disable=wrong-import-position
-    from ._ctypes.function import FunctionBase as _FunctionBase
     from ._ctypes.function import (
         _set_class_function,
         _set_class_module,
         convert_to_dgl_func,
+        FunctionBase as _FunctionBase,
     )

 FunctionHandle = ctypes.c_void_p
...
@@ -9,12 +9,12 @@ import numpy as np
 from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
 from .runtime_ctypes import (
-    dgl_shape_index_t,
     DGLArray,
     DGLArrayHandle,
     DGLContext,
     DGLDataType,
     TypeCode,
+    dgl_shape_index_t,
 )

 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -24,29 +24,29 @@ try:
     if _FFI_MODE == "ctypes":
         raise ImportError()
     if sys.version_info >= (3, 0):
-        from ._cy3.core import NDArrayBase as _NDArrayBase
         from ._cy3.core import (
             _from_dlpack,
             _make_array,
             _reg_extension,
             _set_class_ndarray,
+            NDArrayBase as _NDArrayBase,
         )
     else:
-        from ._cy2.core import NDArrayBase as _NDArrayBase
         from ._cy2.core import (
             _from_dlpack,
             _make_array,
             _reg_extension,
             _set_class_ndarray,
+            NDArrayBase as _NDArrayBase,
         )
 except IMPORT_EXCEPT:
     # pylint: disable=wrong-import-position
-    from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
     from ._ctypes.ndarray import (
         _from_dlpack,
         _make_array,
         _reg_extension,
         _set_class_ndarray,
+        NDArrayBase as _NDArrayBase,
     )
...
@@ -7,7 +7,7 @@ import sys
 from .. import _api_internal
 from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
-from .object_generic import ObjectGeneric, convert_to_object
+from .object_generic import convert_to_object, ObjectGeneric

 # pylint: disable=invalid-name
 IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
@@ -16,15 +16,12 @@ try:
     if _FFI_MODE == "ctypes":
         raise ImportError()
     if sys.version_info >= (3, 0):
-        from ._cy3.core import ObjectBase as _ObjectBase
-        from ._cy3.core import _register_object
+        from ._cy3.core import _register_object, ObjectBase as _ObjectBase
     else:
-        from ._cy2.core import ObjectBase as _ObjectBase
-        from ._cy2.core import _register_object
+        from ._cy2.core import _register_object, ObjectBase as _ObjectBase
 except IMPORT_EXCEPT:
     # pylint: disable=wrong-import-position
-    from ._ctypes.object import ObjectBase as _ObjectBase
-    from ._ctypes.object import _register_object
+    from ._ctypes.object import _register_object, ObjectBase as _ObjectBase

 def _new_object(cls):
...
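Note for readers: the three hunks above only reorder imports inside DGL's FFI loader, which tries the compiled Cython core first and falls back to ctypes. A minimal standalone sketch of that fallback pattern, with a hypothetical module name ("fastimpl" is illustrative, not part of DGL):

# Try the compiled extension first; fall back to a pure-Python implementation.
# "fastimpl" is a hypothetical module name used only for illustration.
try:
    from fastimpl import add
except ImportError:
    def add(a, b):
        # Same signature as the compiled version, so callers never notice.
        return a + b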
"""Utilities for batching/unbatching graphs.""" """Utilities for batching/unbatching graphs."""
from collections.abc import Mapping from collections.abc import Mapping
from . import backend as F from . import backend as F, convert, utils
from .base import ALL, is_all, DGLError, NID, EID from .base import ALL, DGLError, EID, is_all, NID
from .heterograph_index import disjoint_union, slice_gidx
from .heterograph import DGLGraph from .heterograph import DGLGraph
from . import convert from .heterograph_index import disjoint_union, slice_gidx
from . import utils
__all__ = ["batch", "unbatch", "slice_batch"]
__all__ = ['batch', 'unbatch', 'slice_batch']
def batch(graphs, ndata=ALL, edata=ALL): def batch(graphs, ndata=ALL, edata=ALL):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
...@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
unbatch unbatch
""" """
if len(graphs) == 0: if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.') raise DGLError("The input list of graphs cannot be empty.")
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None): if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format( raise DGLError(
type(ndata))) "Invalid argument ndata: must be a string list but got {}.".format(
type(ndata)
)
)
if not (is_all(edata) or isinstance(edata, list) or edata is None): if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format( raise DGLError(
type(edata))) "Invalid argument edata: must be a string list but got {}.".format(
type(edata)
)
)
if any(g.is_block for g in graphs): if any(g.is_block for g in graphs):
raise DGLError("Batching a MFG is not supported.") raise DGLError("Batching a MFG is not supported.")
...@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes] ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
etypes = [etype for _, etype, _ in relations] etypes = [etype for _, etype, _ in relations]
gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs]) gidx = disjoint_union(
graphs[0]._graph.metagraph, [g._graph for g in graphs]
)
retg = DGLGraph(gidx, ntypes, etypes) retg = DGLGraph(gidx, ntypes, etypes)
# Compute batch num nodes # Compute batch num nodes
...@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL): ...@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
# Batch node feature # Batch node feature
if ndata is not None: if ndata is not None:
for ntype_id, ntype in zip(ntype_ids, ntypes): for ntype_id, ntype in zip(ntype_ids, ntypes):
all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs) all_empty = all(
g._graph.number_of_nodes(ntype_id) == 0 for g in graphs
)
frames = [ frames = [
g._node_frames[ntype_id] for g in graphs g._node_frames[ntype_id]
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty] for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype)) ret_feat = _batch_feat_dicts(
frames, ndata, 'nodes["{}"].data'.format(ntype)
)
retg.nodes[ntype].data.update(ret_feat) retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature # Batch edge feature
if edata is not None: if edata is not None:
for etype_id, etype in zip(relation_ids, relations): for etype_id, etype in zip(relation_ids, relations):
all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs) all_empty = all(
g._graph.number_of_edges(etype_id) == 0 for g in graphs
)
frames = [ frames = [
g._edge_frames[etype_id] for g in graphs g._edge_frames[etype_id]
if g._graph.number_of_edges(etype_id) > 0 or all_empty] for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching. # we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype)) ret_feat = _batch_feat_dicts(
frames, edata, "edges[{}].data".format(etype)
)
retg.edges[etype].data.update(ret_feat) retg.edges[etype].data.update(ret_feat)
return retg return retg
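Reviewer aid: the hunks above only re-wrap dgl.batch for line length; behavior is unchanged. A minimal usage sketch, assuming DGL with a PyTorch backend:

import dgl
import torch

g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g1.ndata["h"] = torch.randn(3, 4)
g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])))
g2.ndata["h"] = torch.randn(2, 4)

bg = dgl.batch([g1, g2])     # disjoint union of the two graphs
print(bg.batch_size)         # 2
print(bg.batch_num_nodes())  # tensor([3, 2])
print(bg.ndata["h"].shape)   # torch.Size([5, 4]); features are concatenated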
 def _batch_feat_dicts(frames, keys, feat_dict_name):
     """Internal function to batch feature dictionaries.
@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
     else:
         utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
     # concat features
-    ret_feat = {k : F.cat([fd[k] for fd in frames], 0) for k in keys}
+    ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
     return ret_feat
 def unbatch(g, node_split=None, edge_split=None):
     """Revert the batch operation by split the given graph into a list of small ones.
@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
     num_split = None
     # Parse node_split
     if node_split is None:
-        node_split = {ntype : g.batch_num_nodes(ntype) for ntype in g.ntypes}
+        node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
     elif not isinstance(node_split, Mapping):
         if len(g.ntypes) != 1:
-            raise DGLError('Must provide a dictionary for argument node_split when'
-                           ' there are multiple node types.')
-        node_split = {g.ntypes[0] : node_split}
+            raise DGLError(
+                "Must provide a dictionary for argument node_split when"
+                " there are multiple node types."
+            )
+        node_split = {g.ntypes[0]: node_split}
     if node_split.keys() != set(g.ntypes):
-        raise DGLError('Must specify node_split for each node type.')
+        raise DGLError("Must specify node_split for each node type.")
     for split in node_split.values():
         if num_split is not None and num_split != len(split):
-            raise DGLError('All node_split and edge_split must specify the same number'
-                           ' of split sizes.')
+            raise DGLError(
+                "All node_split and edge_split must specify the same number"
+                " of split sizes."
+            )
         num_split = len(split)

     # Parse edge_split
     if edge_split is None:
-        edge_split = {etype : g.batch_num_edges(etype) for etype in g.canonical_etypes}
+        edge_split = {
+            etype: g.batch_num_edges(etype) for etype in g.canonical_etypes
+        }
     elif not isinstance(edge_split, Mapping):
         if len(g.etypes) != 1:
-            raise DGLError('Must provide a dictionary for argument edge_split when'
-                           ' there are multiple edge types.')
-        edge_split = {g.canonical_etypes[0] : edge_split}
+            raise DGLError(
+                "Must provide a dictionary for argument edge_split when"
+                " there are multiple edge types."
+            )
+        edge_split = {g.canonical_etypes[0]: edge_split}
     if edge_split.keys() != set(g.canonical_etypes):
-        raise DGLError('Must specify edge_split for each canonical edge type.')
+        raise DGLError("Must specify edge_split for each canonical edge type.")
     for split in edge_split.values():
         if num_split is not None and num_split != len(split):
-            raise DGLError('All edge_split and edge_split must specify the same number'
-                           ' of split sizes.')
+            raise DGLError(
+                "All edge_split and edge_split must specify the same number"
+                " of split sizes."
+            )
         num_split = len(split)

-    node_split = {k : F.asnumpy(split).tolist() for k, split in node_split.items()}
-    edge_split = {k : F.asnumpy(split).tolist() for k, split in edge_split.items()}
+    node_split = {
+        k: F.asnumpy(split).tolist() for k, split in node_split.items()
+    }
+    edge_split = {
+        k: F.asnumpy(split).tolist() for k, split in edge_split.items()
+    }

     # Split edges for each relation
     edge_dict_per = [{} for i in range(num_split)]
     for rel in g.canonical_etypes:
         srctype, etype, dsttype = rel
         srcnid_off = dstnid_off = 0
-        u, v = g.edges(order='eid', etype=rel)
+        u, v = g.edges(order="eid", etype=rel)
         us = F.split(u, edge_split[rel], 0)
         vs = F.split(v, edge_split[rel], 0)
         for i, (subu, subv) in enumerate(zip(us, vs)):
             edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
             srcnid_off += node_split[srctype][i]
             dstnid_off += node_split[dsttype][i]
-    num_nodes_dict_per = [{k : split[i] for k, split in node_split.items()}
-                          for i in range(num_split)]
+    num_nodes_dict_per = [
+        {k: split[i] for k, split in node_split.items()}
+        for i in range(num_split)
+    ]

     # Create graphs
-    gs = [convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
-          for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)]
+    gs = [
+        convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
+        for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)
+    ]

     # Unbatch node features
     for ntype in g.ntypes:
@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
     return gs
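Continuing the sketch above: unbatch reverses dgl.batch by splitting along the stored batch sizes (custom splits can be passed via node_split/edge_split tensors):

graphs = dgl.unbatch(bg)
assert len(graphs) == 2
assert graphs[0].num_nodes() == 3 and graphs[1].num_nodes() == 2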
 def slice_batch(g, gid, store_ids=False):
     """Get a particular graph from a batch of graphs.
@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
         if gid == 0:
             start_nid.append(0)
         else:
-            start_nid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0)))
+            start_nid.append(
+                F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))
+            )

     start_eid = []
     num_edges = []
@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
         if gid == 0:
             start_eid.append(0)
         else:
-            start_eid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0)))
+            start_eid.append(
+                F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))
+            )

     # Slice graph structure
-    gidx = slice_gidx(g._graph, utils.toindex(num_nodes), utils.toindex(start_nid),
-                      utils.toindex(num_edges), utils.toindex(start_eid))
+    gidx = slice_gidx(
+        g._graph,
+        utils.toindex(num_nodes),
+        utils.toindex(start_nid),
+        utils.toindex(num_edges),
+        utils.toindex(start_eid),
+    )
     retg = DGLGraph(gidx, g.ntypes, g.etypes)

     # Slice node features
     for ntid, ntype in enumerate(g.ntypes):
         stnid = start_nid[ntid]
         for key, feat in g.nodes[ntype].data.items():
-            subfeats = F.slice_axis(feat, 0, stnid, stnid+num_nodes[ntid])
+            subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])
             retg.nodes[ntype].data[key] = subfeats
         if store_ids:
-            retg.nodes[ntype].data[NID] = F.arange(stnid, stnid+num_nodes[ntid],
-                                                   retg.idtype, retg.device)
+            retg.nodes[ntype].data[NID] = F.arange(
+                stnid, stnid + num_nodes[ntid], retg.idtype, retg.device
+            )

     # Slice edge features
     for etid, etype in enumerate(g.canonical_etypes):
         steid = start_eid[etid]
         for key, feat in g.edges[etype].data.items():
-            subfeats = F.slice_axis(feat, 0, steid, steid+num_edges[etid])
+            subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])
             retg.edges[etype].data[key] = subfeats
         if store_ids:
-            retg.edges[etype].data[EID] = F.arange(steid, steid+num_edges[etid],
-                                                   retg.idtype, retg.device)
+            retg.edges[etype].data[EID] = F.arange(
+                steid, steid + num_edges[etid], retg.idtype, retg.device
+            )
     return retg
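slice_batch extracts one component graph without unbatching everything; with store_ids=True the original in-batch IDs are kept under dgl.NID/dgl.EID. A sketch continuing the batch example above:

g_second = dgl.slice_batch(bg, 1, store_ids=True)
print(g_second.num_nodes())     # 2
print(g_second.ndata[dgl.NID])  # this graph's node IDs within the batch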
"""Module for converting graph from/to other object.""" """Module for converting graph from/to other object."""
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from scipy.sparse import spmatrix
import numpy as np
import networkx as nx import networkx as nx
import numpy as np
from scipy.sparse import spmatrix
from . import backend as F from . import backend as F, graph_index, heterograph_index, utils
from . import heterograph_index from .base import DGLError, EID, ETYPE, NID, NTYPE
from .heterograph import DGLGraph, combine_frames, DGLBlock from .heterograph import combine_frames, DGLBlock, DGLGraph
from . import graph_index
from . import utils
from .base import NTYPE, ETYPE, NID, EID, DGLError
__all__ = [ __all__ = [
'graph', "graph",
'hetero_from_shared_memory', "hetero_from_shared_memory",
'heterograph', "heterograph",
'create_block', "create_block",
'block_to_graph', "block_to_graph",
'to_heterogeneous', "to_heterogeneous",
'to_homogeneous', "to_homogeneous",
'from_scipy', "from_scipy",
'bipartite_from_scipy', "bipartite_from_scipy",
'from_networkx', "from_networkx",
'bipartite_from_networkx', "bipartite_from_networkx",
'to_networkx', "to_networkx",
'from_cugraph', "from_cugraph",
'to_cugraph' "to_cugraph",
] ]
def graph(data,
*, def graph(
num_nodes=None, data,
idtype=None, *,
device=None, num_nodes=None,
row_sorted=False, idtype=None,
col_sorted=False): device=None,
row_sorted=False,
col_sorted=False,
):
"""Create a graph and return. """Create a graph and return.
Parameters Parameters
...@@ -147,25 +148,41 @@ def graph(data, ...@@ -147,25 +148,41 @@ def graph(data,
from_networkx from_networkx
""" """
if isinstance(data, spmatrix): if isinstance(data, spmatrix):
raise DGLError("dgl.graph no longer supports graph construction from a SciPy " raise DGLError(
"sparse matrix, use dgl.from_scipy instead.") "dgl.graph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead."
)
if isinstance(data, nx.Graph): if isinstance(data, nx.Graph):
raise DGLError("dgl.graph no longer supports graph construction from a NetworkX " raise DGLError(
"graph, use dgl.from_networkx instead.") "dgl.graph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead."
)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype) (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)
if num_nodes is not None: # override the number of nodes if num_nodes is not None: # override the number of nodes
if num_nodes < max(urange, vrange): if num_nodes < max(urange, vrange):
raise DGLError('The num_nodes argument must be larger than the max ID in the data,' raise DGLError(
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1)) "The num_nodes argument must be larger than the max ID in the data,"
" but got {} and {}.".format(num_nodes, max(urange, vrange) - 1)
)
urange, vrange = num_nodes, num_nodes urange, vrange = num_nodes, num_nodes
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange, g = create_from_edges(
row_sorted=row_sorted, col_sorted=col_sorted) sparse_fmt,
arrays,
"_N",
"_E",
"_N",
urange,
vrange,
row_sorted=row_sorted,
col_sorted=col_sorted,
)
return g.to(device) return g.to(device)
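Reviewer aid, a minimal call through the reformatted entry point, assuming DGL with a PyTorch backend; num_nodes only adds trailing isolated nodes:

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=4)
print(g.num_nodes(), g.num_edges())  # 4 2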
 def hetero_from_shared_memory(name):
     """Create a heterograph from shared memory with the given name.
@@ -181,13 +198,13 @@ def hetero_from_shared_memory(name):
     -------
     HeteroGraph (in shared memory)
     """
-    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(name)
+    g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
+        name
+    )
     return DGLGraph(g, ntypes, etypes)

-def heterograph(data_dict,
-                num_nodes_dict=None,
-                idtype=None,
-                device=None):
+def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
     """Create a heterogeneous graph and return.

     Parameters
@@ -300,47 +317,77 @@ def heterograph(data_dict,
         num_nodes_dict = defaultdict(int)
     for (sty, ety, dty), data in data_dict.items():
         if isinstance(data, spmatrix):
-            raise DGLError("dgl.heterograph no longer supports graph construction from a SciPy "
-                           "sparse matrix, use dgl.from_scipy instead.")
+            raise DGLError(
+                "dgl.heterograph no longer supports graph construction from a SciPy "
+                "sparse matrix, use dgl.from_scipy instead."
+            )
         if isinstance(data, nx.Graph):
-            raise DGLError("dgl.heterograph no longer supports graph construction from a NetworkX "
-                           "graph, use dgl.from_networkx instead.")
-        is_bipartite = (sty != dty)
+            raise DGLError(
+                "dgl.heterograph no longer supports graph construction from a NetworkX "
+                "graph, use dgl.from_networkx instead."
+            )
+        is_bipartite = sty != dty
         (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
-            data, idtype, bipartite=is_bipartite)
+            data, idtype, bipartite=is_bipartite
+        )
         node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)

         if need_infer:
             num_nodes_dict[sty] = max(num_nodes_dict[sty], urange)
             num_nodes_dict[dty] = max(num_nodes_dict[dty], vrange)
         else:  # sanity check
             if num_nodes_dict[sty] < urange:
-                raise DGLError('The given number of nodes of node type {} must be larger than'
-                               ' the max ID in the data, but got {} and {}.'.format(
-                                   sty, num_nodes_dict[sty], urange - 1))
+                raise DGLError(
+                    "The given number of nodes of node type {} must be larger than"
+                    " the max ID in the data, but got {} and {}.".format(
+                        sty, num_nodes_dict[sty], urange - 1
+                    )
+                )
             if num_nodes_dict[dty] < vrange:
-                raise DGLError('The given number of nodes of node type {} must be larger than'
-                               ' the max ID in the data, but got {} and {}.'.format(
-                                   dty, num_nodes_dict[dty], vrange - 1))
+                raise DGLError(
+                    "The given number of nodes of node type {} must be larger than"
+                    " the max ID in the data, but got {} and {}.".format(
+                        dty, num_nodes_dict[dty], vrange - 1
+                    )
+                )
     # Create the graph
-    metagraph, ntypes, etypes, relations = heterograph_index.create_metagraph_index(
-        num_nodes_dict.keys(), node_tensor_dict.keys())
-    num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64")
+    (
+        metagraph,
+        ntypes,
+        etypes,
+        relations,
+    ) = heterograph_index.create_metagraph_index(
+        num_nodes_dict.keys(), node_tensor_dict.keys()
+    )
+    num_nodes_per_type = utils.toindex(
+        [num_nodes_dict[ntype] for ntype in ntypes], "int64"
+    )
     rel_graphs = []
     for srctype, etype, dsttype in relations:
         sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
-        g = create_from_edges(sparse_fmt, arrays, srctype, etype, dsttype,
-                              num_nodes_dict[srctype], num_nodes_dict[dsttype])
+        g = create_from_edges(
+            sparse_fmt,
+            arrays,
+            srctype,
+            etype,
+            dsttype,
+            num_nodes_dict[srctype],
+            num_nodes_dict[dsttype],
+        )
         rel_graphs.append(g)

     # create graph index
     hgidx = heterograph_index.create_heterograph_from_relations(
-        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
+        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type
+    )
     retg = DGLGraph(hgidx, ntypes, etypes)
     return retg.to(device)
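A minimal dgl.heterograph sketch (PyTorch backend assumed); node counts are inferred from the maximum IDs when num_nodes_dict is omitted:

import dgl
import torch

hg = dgl.heterograph({
    ("user", "follows", "user"): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ("user", "plays", "game"): (torch.tensor([0, 2]), torch.tensor([0, 0])),
})
print(hg.ntypes)             # ['game', 'user'], sorted for determinism
print(hg.num_nodes("user"))  # 3, inferred from the max user ID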
-def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None):
+def create_block(
+    data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None
+):
     """Create a message flow graph (MFG) as a :class:`DGLBlock` object.

     Parameters
@@ -464,21 +511,25 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
     """
     need_infer = num_src_nodes is None and num_dst_nodes is None
     if not isinstance(data_dict, Mapping):
-        data_dict = {('_N', '_E', '_N'): data_dict}
+        data_dict = {("_N", "_E", "_N"): data_dict}
         if not need_infer:
-            assert isinstance(num_src_nodes, int), \
-                "num_src_nodes must be a pair of integers if data_dict is not a dict"
-            assert isinstance(num_dst_nodes, int), \
-                "num_dst_nodes must be a pair of integers if data_dict is not a dict"
-            num_src_nodes = {'_N': num_src_nodes}
-            num_dst_nodes = {'_N': num_dst_nodes}
+            assert isinstance(
+                num_src_nodes, int
+            ), "num_src_nodes must be a pair of integers if data_dict is not a dict"
+            assert isinstance(
+                num_dst_nodes, int
+            ), "num_dst_nodes must be a pair of integers if data_dict is not a dict"
+            num_src_nodes = {"_N": num_src_nodes}
+            num_dst_nodes = {"_N": num_dst_nodes}
     else:
         if not need_infer:
-            assert isinstance(num_src_nodes, Mapping), \
-                "num_src_nodes must be a dict if data_dict is a dict"
-            assert isinstance(num_dst_nodes, Mapping), \
-                "num_dst_nodes must be a dict if data_dict is a dict"
+            assert isinstance(
+                num_src_nodes, Mapping
+            ), "num_src_nodes must be a dict if data_dict is a dict"
+            assert isinstance(
+                num_dst_nodes, Mapping
+            ), "num_dst_nodes must be a dict if data_dict is a dict"

     if need_infer:
         num_src_nodes = defaultdict(int)
@@ -488,20 +539,27 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
     node_tensor_dict = {}
     for (sty, ety, dty), data in data_dict.items():
         (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
-            data, idtype, bipartite=True)
+            data, idtype, bipartite=True
+        )
         node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
         if need_infer:
             num_src_nodes[sty] = max(num_src_nodes[sty], urange)
             num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)
         else:  # sanity check
             if num_src_nodes[sty] < urange:
-                raise DGLError('The given number of nodes of source node type {} must be larger'
-                               ' than the max ID in the data, but got {} and {}.'.format(
-                                   sty, num_src_nodes[sty], urange - 1))
+                raise DGLError(
+                    "The given number of nodes of source node type {} must be larger"
+                    " than the max ID in the data, but got {} and {}.".format(
+                        sty, num_src_nodes[sty], urange - 1
+                    )
+                )
             if num_dst_nodes[dty] < vrange:
-                raise DGLError('The given number of nodes of destination node type {} must be'
-                               ' larger than the max ID in the data, but got {} and {}.'.format(
-                                   dty, num_dst_nodes[dty], vrange - 1))
+                raise DGLError(
+                    "The given number of nodes of destination node type {} must be"
+                    " larger than the max ID in the data, but got {} and {}.".format(
+                        dty, num_dst_nodes[dty], vrange - 1
+                    )
+                )

     # Create the graph
     # Sort the ntypes and relation tuples to have a deterministic order for the same set
@@ -511,10 +569,14 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
     relations = list(sorted(node_tensor_dict.keys()))
     num_nodes_per_type = utils.toindex(
-        [num_src_nodes[ntype] for ntype in srctypes] +
-        [num_dst_nodes[ntype] for ntype in dsttypes], "int64")
+        [num_src_nodes[ntype] for ntype in srctypes]
+        + [num_dst_nodes[ntype] for ntype in dsttypes],
+        "int64",
+    )

     srctype_dict = {ntype: i for i, ntype in enumerate(srctypes)}
-    dsttype_dict = {ntype: i + len(srctypes) for i, ntype in enumerate(dsttypes)}
+    dsttype_dict = {
+        ntype: i + len(srctypes) for i, ntype in enumerate(dsttypes)
+    }
     meta_edges_src = []
     meta_edges_dst = []
@@ -525,20 +587,30 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
         meta_edges_dst.append(dsttype_dict[dsttype])
         etypes.append(etype)
         sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
-        g = create_from_edges(sparse_fmt, arrays, 'SRC/' + srctype, etype, 'DST/' + dsttype,
-                              num_src_nodes[srctype], num_dst_nodes[dsttype])
+        g = create_from_edges(
+            sparse_fmt,
+            arrays,
+            "SRC/" + srctype,
+            etype,
+            "DST/" + dsttype,
+            num_src_nodes[srctype],
+            num_dst_nodes[dsttype],
+        )
         rel_graphs.append(g)

     # metagraph is DGLGraph, currently still using int64 as index dtype
     metagraph = graph_index.from_coo(
-        len(srctypes) + len(dsttypes), meta_edges_src, meta_edges_dst, True)
+        len(srctypes) + len(dsttypes), meta_edges_src, meta_edges_dst, True
+    )
     # create graph index
     hgidx = heterograph_index.create_heterograph_from_relations(
-        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
+        metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type
+    )
     retg = DGLBlock(hgidx, (srctypes, dsttypes), etypes)
     return retg.to(device)
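A minimal dgl.create_block sketch (PyTorch backend assumed); a message flow graph has separate source and destination node sets:

import dgl
import torch

block = dgl.create_block(
    (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
    num_src_nodes=3,
    num_dst_nodes=2,
)
print(block.num_src_nodes(), block.num_dst_nodes())  # 3 2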
 def block_to_graph(block):
     """Convert a message flow graph (MFG) as a :class:`DGLBlock` object to a :class:`DGLGraph`.
@@ -568,22 +640,26 @@ def block_to_graph(block):
           num_edges={('A_src', 'AB', 'B_dst'): 3, ('B_src', 'BA', 'A_dst'): 2},
           metagraph=[('A_src', 'B_dst', 'AB'), ('B_src', 'A_dst', 'BA')])
     """
-    new_types = [ntype + '_src' for ntype in block.srctypes] + \
-                [ntype + '_dst' for ntype in block.dsttypes]
+    new_types = [ntype + "_src" for ntype in block.srctypes] + [
+        ntype + "_dst" for ntype in block.dsttypes
+    ]
     retg = DGLGraph(block._graph, new_types, block.etypes)

     for srctype in block.srctypes:
-        retg.nodes[srctype + '_src'].data.update(block.srcnodes[srctype].data)
+        retg.nodes[srctype + "_src"].data.update(block.srcnodes[srctype].data)
     for dsttype in block.dsttypes:
-        retg.nodes[dsttype + '_dst'].data.update(block.dstnodes[dsttype].data)
+        retg.nodes[dsttype + "_dst"].data.update(block.dstnodes[dsttype].data)
     for srctype, etype, dsttype in block.canonical_etypes:
-        retg.edges[srctype + '_src', etype, dsttype + '_dst'].data.update(
-            block.edges[srctype, etype, dsttype].data)
+        retg.edges[srctype + "_src", etype, dsttype + "_dst"].data.update(
+            block.edges[srctype, etype, dsttype].data
+        )

     return retg
-def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
-                     etype_field=ETYPE, metagraph=None):
+def to_heterogeneous(
+    G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None
+):
     """Convert a homogeneous graph to a heterogeneous graph and return.

     The input graph should have only one type of nodes and edges. Each node and edge
@@ -691,10 +767,16 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
     --------
     to_homogeneous
     """
-    if (hasattr(G, 'ntypes') and len(G.ntypes) > 1
-            or hasattr(G, 'etypes') and len(G.etypes) > 1):
-        raise DGLError('The input graph should be homogeneous and have only one '
-                       ' type of nodes and edges.')
+    if (
+        hasattr(G, "ntypes")
+        and len(G.ntypes) > 1
+        or hasattr(G, "etypes")
+        and len(G.etypes) > 1
+    ):
+        raise DGLError(
+            "The input graph should be homogeneous and have only one "
+            " type of nodes and edges."
+        )

     num_ntypes = len(ntypes)
     idtype = G.idtype
@@ -706,15 +788,15 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
     # relabel nodes to per-type local IDs
     ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)
     ntype_offset = np.insert(np.cumsum(ntype_count), 0, 0)
-    ntype_ids_sortidx = np.argsort(ntype_ids, kind='stable')
+    ntype_ids_sortidx = np.argsort(ntype_ids, kind="stable")
     ntype_local_ids = np.zeros_like(ntype_ids)
     node_groups = []
     for i in range(num_ntypes):
-        node_group = ntype_ids_sortidx[ntype_offset[i]:ntype_offset[i+1]]
+        node_group = ntype_ids_sortidx[ntype_offset[i] : ntype_offset[i + 1]]
         node_groups.append(node_group)
         ntype_local_ids[node_group] = np.arange(ntype_count[i])

-    src, dst = G.all_edges(order='eid')
+    src, dst = G.all_edges(order="eid")
     src = F.asnumpy(src)
     dst = F.asnumpy(dst)
     src_local = ntype_local_ids[src]
@@ -729,21 +811,28 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
     # above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the
     # canonical edge type j. We can then group the edges of the same type together.
     if metagraph is None:
-        canonical_etids, _, etype_remapped = \
-            utils.make_invmap(list(tuple(_) for _ in edge_ctids), False)
-        etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None])
+        canonical_etids, _, etype_remapped = utils.make_invmap(
+            list(tuple(_) for _ in edge_ctids), False
+        )
+        etype_mask = (
+            etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None]
+        )
     else:
         ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}
         etypes_invmap = {et: i for i, et in enumerate(etypes)}
         canonical_etids = []
-        for i, (srctype, dsttype, etype) in enumerate(metagraph.edges(keys=True)):
+        for i, (srctype, dsttype, etype) in enumerate(
+            metagraph.edges(keys=True)
+        ):
             srctype_id = ntypes_invmap[srctype]
             etype_id = etypes_invmap[etype]
             dsttype_id = ntypes_invmap[dsttype]
             canonical_etids.append((srctype_id, etype_id, dsttype_id))
         canonical_etids = np.asarray(canonical_etids)
         etype_mask = (edge_ctids[None, :] == canonical_etids[:, None]).all(2)
-    edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))]
+    edge_groups = [
+        etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))
+    ]

     data_dict = dict()
     canonical_etypes = []
@@ -751,13 +840,12 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
         src_of_etype = src_local[edge_groups[i]]
         dst_of_etype = dst_local[edge_groups[i]]
         canonical_etypes.append((ntypes[stid], etypes[etid], ntypes[dtid]))
-        data_dict[canonical_etypes[-1]] = \
-            (src_of_etype, dst_of_etype)
-    hg = heterograph(data_dict,
-                     dict(zip(ntypes, ntype_count)),
-                     idtype=idtype, device=device)
+        data_dict[canonical_etypes[-1]] = (src_of_etype, dst_of_etype)
+    hg = heterograph(
+        data_dict, dict(zip(ntypes, ntype_count)), idtype=idtype, device=device
+    )

-    ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
+    ntype2ngrp = {ntype: node_groups[ntid] for ntid, ntype in enumerate(ntypes)}

     # features
     for key, data in G.ndata.items():
@@ -772,19 +860,26 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
             continue
         for etid in range(len(hg.canonical_etypes)):
             rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))
-            hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][key] = \
-                F.gather_row(data, rows)
+            hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][
+                key
+            ] = F.gather_row(data, rows)

     # Record the original IDs of the nodes/edges
     for ntid, ntype in enumerate(hg.ntypes):
-        hg._node_frames[ntid][NID] = F.copy_to(F.tensor(ntype2ngrp[ntype]), device)
+        hg._node_frames[ntid][NID] = F.copy_to(
+            F.tensor(ntype2ngrp[ntype]), device
+        )
     for etid in range(len(hg.canonical_etypes)):
-        hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][EID] = \
-            F.copy_to(F.tensor(edge_groups[etid]), device)
+        hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][
+            EID
+        ] = F.copy_to(F.tensor(edge_groups[etid]), device)

     return hg
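A minimal to_heterogeneous sketch: the homogeneous input carries per-node and per-edge type indices under dgl.NTYPE/dgl.ETYPE (PyTorch backend assumed):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.ndata[dgl.NTYPE] = torch.tensor([0, 0, 1])  # user, user, game
g.edata[dgl.ETYPE] = torch.tensor([0, 1])     # follows, plays
hg = dgl.to_heterogeneous(g, ["user", "game"], ["follows", "plays"])
print(hg.canonical_etypes)  # two relations recovered from the type indices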
-def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=False):
+def to_homogeneous(
+    G, ndata=None, edata=None, store_type=True, return_count=False
+):
     """Convert a heterogeneous graph to a homogeneous graph and return.

     By default, the function stores the node and edge types of the input graph as
@@ -902,7 +997,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
     for etype_id, etype in enumerate(G.canonical_etypes):
         srctype, _, dsttype = etype
-        src, dst = G.all_edges(etype=etype, order='eid')
+        src, dst = G.all_edges(etype=etype, order="eid")
         num_edges = len(src)
         srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
         dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
@@ -913,16 +1008,24 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
         etype_count.append(num_edges)
         eids.append(F.arange(0, num_edges, G.idtype, G.device))

-    retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
-                 idtype=G.idtype, device=G.device)
+    retg = graph(
+        (F.cat(srcs, 0), F.cat(dsts, 0)),
+        num_nodes=total_num_nodes,
+        idtype=G.idtype,
+        device=G.device,
+    )

     # copy features
     if ndata is None:
         ndata = []
     if edata is None:
         edata = []
-    comb_nf = combine_frames(G._node_frames, range(len(G.ntypes)), col_names=ndata)
-    comb_ef = combine_frames(G._edge_frames, range(len(G.etypes)), col_names=edata)
+    comb_nf = combine_frames(
+        G._node_frames, range(len(G.ntypes)), col_names=ndata
+    )
+    comb_ef = combine_frames(
+        G._edge_frames, range(len(G.etypes)), col_names=edata
+    )
     if comb_nf is not None:
         retg.ndata.update(comb_nf)
     if comb_ef is not None:
@@ -939,10 +1042,8 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
     else:
         return retg
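And the reverse direction; by default the node/edge type indices are stored back under dgl.NTYPE/dgl.ETYPE (minimal sketch, PyTorch backend assumed):

import dgl
import torch

hg = dgl.heterograph({
    ("user", "plays", "game"): (torch.tensor([0, 1]), torch.tensor([0, 0]))
})
homo = dgl.to_homogeneous(hg)
print(homo.num_nodes())       # 3: one game node, then two user nodes
print(homo.ndata[dgl.NTYPE])  # tensor([0, 1, 1]), type ID per node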
-def from_scipy(sp_mat,
-               eweight_name=None,
-               idtype=None,
-               device=None):
+def from_scipy(sp_mat, eweight_name=None, idtype=None, device=None):
     """Create a graph from a SciPy sparse matrix and return.

     Parameters
@@ -1019,20 +1120,23 @@ def from_scipy(sp_mat,
     num_rows = sp_mat.shape[0]
     num_cols = sp_mat.shape[1]
     if num_rows != num_cols:
-        raise DGLError('Expect the number of rows to be the same as the number of columns for '
-                       'sp_mat, got {:d} and {:d}.'.format(num_rows, num_cols))
+        raise DGLError(
+            "Expect the number of rows to be the same as the number of columns for "
+            "sp_mat, got {:d} and {:d}.".format(num_rows, num_cols)
+        )

-    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype)
-    g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)
+    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
+        sp_mat, idtype
+    )
+    g = create_from_edges(sparse_fmt, arrays, "_N", "_E", "_N", urange, vrange)

     if eweight_name is not None:
         g.edata[eweight_name] = F.tensor(sp_mat.data)

     return g.to(device)
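A minimal from_scipy sketch; each nonzero entry becomes an edge, and eweight_name stores the values as an edge feature:

import dgl
from scipy.sparse import coo_matrix

sp = coo_matrix(([0.5, 2.0], ([0, 1], [1, 2])), shape=(3, 3))
g = dgl.from_scipy(sp, eweight_name="w")
print(g.num_edges())  # 2
print(g.edata["w"])   # the nonzero values as edge weights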
-def bipartite_from_scipy(sp_mat,
-                         utype, etype, vtype,
-                         eweight_name=None,
-                         idtype=None,
-                         device=None):
+def bipartite_from_scipy(
+    sp_mat, utype, etype, vtype, eweight_name=None, idtype=None, device=None
+):
     """Create a uni-directional bipartite graph from a SciPy sparse matrix and return.

     The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
@@ -1116,18 +1220,25 @@ def bipartite_from_scipy(sp_mat,
     heterograph
     bipartite_from_networkx
     """
-    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype, bipartite=True)
-    g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)
+    (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
+        sp_mat, idtype, bipartite=True
+    )
+    g = create_from_edges(
+        sparse_fmt, arrays, utype, etype, vtype, urange, vrange
+    )

     if eweight_name is not None:
         g.edata[eweight_name] = F.tensor(sp_mat.data)

     return g.to(device)
-def from_networkx(nx_graph,
-                  node_attrs=None,
-                  edge_attrs=None,
-                  edge_id_attr_name=None,
-                  idtype=None,
-                  device=None):
+def from_networkx(
+    nx_graph,
+    node_attrs=None,
+    edge_attrs=None,
+    edge_id_attr_name=None,
+    idtype=None,
+    device=None,
+):
     """Create a graph from a NetworkX graph and return.

     .. note::
@@ -1221,27 +1332,38 @@ def from_networkx(nx_graph,
     from_scipy
     """
     # Sanity check
-    if edge_id_attr_name is not None and \
-            edge_id_attr_name not in next(iter(nx_graph.edges(data=True)))[-1]:
-        raise DGLError('Failed to find the pre-specified edge IDs in the edge features of '
-                       'the NetworkX graph with name {}'.format(edge_id_attr_name))
+    if (
+        edge_id_attr_name is not None
+        and edge_id_attr_name not in next(iter(nx_graph.edges(data=True)))[-1]
+    ):
+        raise DGLError(
+            "Failed to find the pre-specified edge IDs in the edge features of "
+            "the NetworkX graph with name {}".format(edge_id_attr_name)
+        )

-    if not nx_graph.is_directed() and not (edge_id_attr_name is None and edge_attrs is None):
-        raise DGLError('Expect edge_id_attr_name and edge_attrs to be None when nx_graph is '
-                       'undirected, got {} and {}'.format(edge_id_attr_name, edge_attrs))
+    if not nx_graph.is_directed() and not (
+        edge_id_attr_name is None and edge_attrs is None
+    ):
+        raise DGLError(
+            "Expect edge_id_attr_name and edge_attrs to be None when nx_graph is "
+            "undirected, got {} and {}".format(edge_id_attr_name, edge_attrs)
+        )

     # Relabel nodes using consecutive integers starting from 0
-    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')
+    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering="sorted")

     if not nx_graph.is_directed():
         nx_graph = nx_graph.to_directed()

     (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
-        nx_graph, idtype, edge_id_attr_name=edge_id_attr_name)
+        nx_graph, idtype, edge_id_attr_name=edge_id_attr_name
+    )

-    g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)
+    g = create_from_edges(sparse_fmt, arrays, "_N", "_E", "_N", urange, vrange)

     # nx_graph.edges(data=True) returns src, dst, attr_dict
-    has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
+    has_edge_id = (
+        nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
+    )

     # handle features
     # copy attributes
@@ -1250,6 +1372,7 @@ def from_networkx(nx_graph,
             return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
         else:
             return F.tensor(lst)
+
     if node_attrs is not None:
         # mapping from feature name to a list of tensors to be concatenated
         attr_dict = defaultdict(list)
@@ -1269,9 +1392,11 @@ def from_networkx(nx_graph,
         num_edges = g.number_of_edges()
         for _, _, attrs in nx_graph.edges(data=True):
             if attrs[edge_id_attr_name] >= num_edges:
-                raise DGLError('Expect the pre-specified edge ids to be'
-                               ' smaller than the number of edges --'
-                               ' {}, got {}.'.format(num_edges, attrs['id']))
+                raise DGLError(
+                    "Expect the pre-specified edge ids to be"
+                    " smaller than the number of edges --"
+                    " {}, got {}.".format(num_edges, attrs["id"])
+                )
             for key in edge_attrs:
                 attr_dict[key][attrs[edge_id_attr_name]] = attrs[key]
     else:
@@ -1283,17 +1408,26 @@ def from_networkx(nx_graph,
         for attr in edge_attrs:
             for val in attr_dict[attr]:
                 if val is None:
-                    raise DGLError('Not all edges have attribute {}.'.format(attr))
+                    raise DGLError(
+                        "Not all edges have attribute {}.".format(attr)
+                    )
             g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)
     return g.to(device)
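A minimal from_networkx sketch; undirected NetworkX graphs are converted with both edge directions (PyTorch backend assumed):

import dgl
import networkx as nx
import torch

nxg = nx.path_graph(3)  # undirected 0-1-2
for n in nxg.nodes:
    nxg.nodes[n]["feat"] = torch.randn(4)
g = dgl.from_networkx(nxg, node_attrs=["feat"])
print(g.num_edges())          # 4, both directions per undirected edge
print(g.ndata["feat"].shape)  # torch.Size([3, 4])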
-def bipartite_from_networkx(nx_graph,
-                            utype, etype, vtype,
-                            u_attrs=None, e_attrs=None, v_attrs=None,
-                            edge_id_attr_name=None,
-                            idtype=None,
-                            device=None):
+def bipartite_from_networkx(
+    nx_graph,
+    utype,
+    etype,
+    vtype,
+    u_attrs=None,
+    e_attrs=None,
+    v_attrs=None,
+    edge_id_attr_name=None,
+    idtype=None,
+    device=None,
+):
     """Create a unidirectional bipartite graph from a NetworkX graph and return.

     The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
@@ -1403,42 +1537,58 @@ def bipartite_from_networkx(nx_graph,
     bipartite_from_scipy
     """
     if not nx_graph.is_directed():
-        raise DGLError('Expect nx_graph to be a directed NetworkX graph.')
+        raise DGLError("Expect nx_graph to be a directed NetworkX graph.")

-    if edge_id_attr_name is not None and \
-            not edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]:
-        raise DGLError('Failed to find the pre-specified edge IDs in the edge features '
-                       'of the NetworkX graph with name {}'.format(edge_id_attr_name))
+    if (
+        edge_id_attr_name is not None
+        and not edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
+    ):
+        raise DGLError(
+            "Failed to find the pre-specified edge IDs in the edge features "
+            "of the NetworkX graph with name {}".format(edge_id_attr_name)
+        )

     # Get the source and destination node sets
     top_nodes = set()
     bottom_nodes = set()
     for n, ndata in nx_graph.nodes(data=True):
-        if 'bipartite' not in ndata:
-            raise DGLError('Expect the node {} to have attribute bipartite'.format(n))
-        if ndata['bipartite'] == 0:
+        if "bipartite" not in ndata:
+            raise DGLError(
+                "Expect the node {} to have attribute bipartite".format(n)
+            )
+        if ndata["bipartite"] == 0:
             top_nodes.add(n)
-        elif ndata['bipartite'] == 1:
+        elif ndata["bipartite"] == 1:
             bottom_nodes.add(n)
         else:
-            raise ValueError('Expect the bipartite attribute of the node {} to be 0 or 1, '
-                             'got {}'.format(n, ndata['bipartite']))
+            raise ValueError(
+                "Expect the bipartite attribute of the node {} to be 0 or 1, "
+                "got {}".format(n, ndata["bipartite"])
+            )

     # Separately relabel the source and destination nodes.
     top_nodes = sorted(top_nodes)
     bottom_nodes = sorted(bottom_nodes)
-    top_map = {n : i for i, n in enumerate(top_nodes)}
-    bottom_map = {n : i for i, n in enumerate(bottom_nodes)}
+    top_map = {n: i for i, n in enumerate(top_nodes)}
+    bottom_map = {n: i for i, n in enumerate(bottom_nodes)}

     # Get the node tensors and the number of nodes
     (sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
-        nx_graph, idtype, bipartite=True,
+        nx_graph,
+        idtype,
+        bipartite=True,
         edge_id_attr_name=edge_id_attr_name,
-        top_map=top_map, bottom_map=bottom_map)
+        top_map=top_map,
+        bottom_map=bottom_map,
+    )

-    g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)
+    g = create_from_edges(
+        sparse_fmt, arrays, utype, etype, vtype, urange, vrange
+    )

     # nx_graph.edges(data=True) returns src, dst, attr_dict
-    has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
+    has_edge_id = (
+        nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
+    )

     # handle features
     # copy attributes
@@ -1485,11 +1635,14 @@ def bipartite_from_networkx(nx_graph,
         for attr in e_attrs:
             for val in attr_dict[attr]:
                 if val is None:
-                    raise DGLError('Not all edges have attribute {}.'.format(attr))
+                    raise DGLError(
+                        "Not all edges have attribute {}.".format(attr)
+                    )
             g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)

     return g.to(device)
 def to_networkx(g, node_attrs=None, edge_attrs=None):
     """Convert a homogeneous graph to a NetworkX graph and return.
@@ -1537,9 +1690,11 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
     (2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})])
     """
     if g.device != F.cpu():
-        raise DGLError('Cannot convert a CUDA graph to networkx. Call g.cpu() first.')
+        raise DGLError(
+            "Cannot convert a CUDA graph to networkx. Call g.cpu() first."
+        )
     if not g.is_homogeneous:
-        raise DGLError('dgl.to_networkx only supports homogeneous graphs.')
+        raise DGLError("dgl.to_networkx only supports homogeneous graphs.")
     src, dst = g.edges()
     src = F.asnumpy(src)
     dst = F.asnumpy(dst)
@@ -1552,16 +1707,22 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
     if node_attrs is not None:
         for nid, attr in nx_graph.nodes(data=True):
             feat_dict = g._get_n_repr(0, nid)
-            attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
+            attr.update(
+                {key: F.squeeze(feat_dict[key], 0) for key in node_attrs}
+            )
     if edge_attrs is not None:
         for _, _, attr in nx_graph.edges(data=True):
-            eid = attr['id']
+            eid = attr["id"]
             feat_dict = g._get_e_repr(0, eid)
-            attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
+            attr.update(
+                {key: F.squeeze(feat_dict[key], 0) for key in edge_attrs}
+            )
     return nx_graph

 DGLGraph.to_networkx = to_networkx
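The round trip back to NetworkX yields a MultiDiGraph whose edges carry an 'id' attribute; requested features are attached per node/edge. A minimal sketch (PyTorch backend assumed):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.ndata["h"] = torch.zeros(3, 2)
nxg = dgl.to_networkx(g, node_attrs=["h"])
print(nxg.nodes[0]["h"])  # tensor([0., 0.])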
 def to_cugraph(g):
     """Convert a DGL graph to a :class:`cugraph.Graph` and return.
@@ -1595,30 +1756,36 @@ def to_cugraph(g):
        1    1    1
     """
-    if g.device.type != 'cuda':
-        raise DGLError(f"Cannot convert a {g.device.type} graph to cugraph." +
-                       "Call g.to('cuda') first.")
+    if g.device.type != "cuda":
+        raise DGLError(
+            f"Cannot convert a {g.device.type} graph to cugraph."
+            + "Call g.to('cuda') first."
+        )
     if not g.is_homogeneous:
         raise DGLError("dgl.to_cugraph only supports homogeneous graphs.")

     try:
-        import cugraph
         import cudf
+        import cugraph
     except ModuleNotFoundError:
-        raise ModuleNotFoundError("to_cugraph requires cugraph which could not be imported")
+        raise ModuleNotFoundError(
+            "to_cugraph requires cugraph which could not be imported"
+        )

     edgelist = g.edges()
     src_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[0]))
     dst_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[1]))
-    cudf_data = cudf.DataFrame({'source':src_ser, 'destination':dst_ser})
+    cudf_data = cudf.DataFrame({"source": src_ser, "destination": dst_ser})

     g_cugraph = cugraph.Graph(directed=True)
-    g_cugraph.from_cudf_edgelist(cudf_data,
-                                 source='source',
-                                 destination='destination')
+    g_cugraph.from_cudf_edgelist(
+        cudf_data, source="source", destination="destination"
+    )
     return g_cugraph

 DGLGraph.to_cugraph = to_cugraph
 def from_cugraph(cugraph_graph):
     """Create a graph from a :class:`cugraph.Graph` object.
@@ -1660,21 +1827,29 @@ def from_cugraph(cugraph_graph):
         cugraph_graph = cugraph_graph.to_directed()

     edges = cugraph_graph.edges()
-    src_t = F.zerocopy_from_dlpack(edges['src'].to_dlpack())
-    dst_t = F.zerocopy_from_dlpack(edges['dst'].to_dlpack())
-    g = graph((src_t,dst_t))
+    src_t = F.zerocopy_from_dlpack(edges["src"].to_dlpack())
+    dst_t = F.zerocopy_from_dlpack(edges["dst"].to_dlpack())
+    g = graph((src_t, dst_t))

     return g

 ############################################################
 # Internal APIs
 ############################################################

-def create_from_edges(sparse_fmt, arrays,
-                      utype, etype, vtype,
-                      urange, vrange,
-                      row_sorted=False,
-                      col_sorted=False):
+def create_from_edges(
+    sparse_fmt,
+    arrays,
+    utype,
+    etype,
+    vtype,
+    urange,
+    vrange,
+    row_sorted=False,
+    col_sorted=False,
+):
     """Internal function to create a graph from incident nodes with types.

     utype could be equal to vtype
@@ -1713,16 +1888,30 @@ def create_from_edges(sparse_fmt, arrays,
     else:
         num_ntypes = 2

-    if sparse_fmt == 'coo':
+    if sparse_fmt == "coo":
         u, v = arrays
         hgidx = heterograph_index.create_unitgraph_from_coo(
-            num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
-            row_sorted, col_sorted)
+            num_ntypes,
+            urange,
+            vrange,
+            u,
+            v,
+            ["coo", "csr", "csc"],
+            row_sorted,
+            col_sorted,
+        )
     else:  # 'csr' or 'csc'
         indptr, indices, eids = arrays
         hgidx = heterograph_index.create_unitgraph_from_csr(
-            num_ntypes, urange, vrange, indptr, indices, eids, ['coo', 'csr', 'csc'],
-            sparse_fmt == 'csc')
+            num_ntypes,
+            urange,
+            vrange,
+            indptr,
+            indices,
+            eids,
+            ["coo", "csr", "csc"],
+            sparse_fmt == "csc",
+        )
     if utype == vtype:
         return DGLGraph(hgidx, [utype], [etype])
     else:
...
@@ -2,10 +2,8 @@
 # pylint: disable=not-callable
 import numpy as np

-from . import backend as F
-from . import function as fn
-from . import ops
-from .base import ALL, EID, NID, DGLError, dgl_warning, is_all
+from . import backend as F, function as fn, ops
+from .base import ALL, dgl_warning, DGLError, EID, is_all, NID
 from .frame import Frame
 from .udf import EdgeBatch, NodeBatch
...