Unverified Commit a566b60b authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent d1827488
......@@ -14,26 +14,26 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
else:
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import (
_set_class_function,
_set_class_module,
convert_to_dgl_func,
FunctionBase as _FunctionBase,
)
FunctionHandle = ctypes.c_void_p
......
......@@ -9,12 +9,12 @@ import numpy as np
from .base import _FFI_MODE, _LIB, c_array, c_str, check_call, string_types
from .runtime_ctypes import (
dgl_shape_index_t,
DGLArray,
DGLArrayHandle,
DGLContext,
DGLDataType,
TypeCode,
dgl_shape_index_t,
)
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -24,29 +24,29 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import NDArrayBase as _NDArrayBase
from ._cy3.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
else:
from ._cy2.core import NDArrayBase as _NDArrayBase
from ._cy2.core import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
from ._ctypes.ndarray import (
_from_dlpack,
_make_array,
_reg_extension,
_set_class_ndarray,
NDArrayBase as _NDArrayBase,
)
......
......@@ -7,7 +7,7 @@ import sys
from .. import _api_internal
from .base import _FFI_MODE, _LIB, c_str, check_call, py_str
from .object_generic import ObjectGeneric, convert_to_object
from .object_generic import convert_to_object, ObjectGeneric
# pylint: disable=invalid-name
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
......@@ -16,15 +16,12 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import _register_object
from ._cy3.core import _register_object, ObjectBase as _ObjectBase
else:
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import _register_object
from ._cy2.core import _register_object, ObjectBase as _ObjectBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.object import ObjectBase as _ObjectBase
from ._ctypes.object import _register_object
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
def _new_object(cls):
......
"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from . import backend as F
from .base import ALL, is_all, DGLError, NID, EID
from .heterograph_index import disjoint_union, slice_gidx
from . import backend as F, convert, utils
from .base import ALL, DGLError, EID, is_all, NID
from .heterograph import DGLGraph
from . import convert
from . import utils
from .heterograph_index import disjoint_union, slice_gidx
__all__ = ["batch", "unbatch", "slice_batch"]
__all__ = ['batch', 'unbatch', 'slice_batch']
def batch(graphs, ndata=ALL, edata=ALL):
r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
......@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
unbatch
"""
if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.')
raise DGLError("The input list of graphs cannot be empty.")
if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata)))
raise DGLError(
"Invalid argument ndata: must be a string list but got {}.".format(
type(ndata)
)
)
if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format(
type(edata)))
raise DGLError(
"Invalid argument edata: must be a string list but got {}.".format(
type(edata)
)
)
if any(g.is_block for g in graphs):
raise DGLError("Batching a MFG is not supported.")
......@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
etypes = [etype for _, etype, _ in relations]
gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
gidx = disjoint_union(
graphs[0]._graph.metagraph, [g._graph for g in graphs]
)
retg = DGLGraph(gidx, ntypes, etypes)
# Compute batch num nodes
......@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
# Batch node feature
if ndata is not None:
for ntype_id, ntype in zip(ntype_ids, ntypes):
all_empty = all(g._graph.number_of_nodes(ntype_id) == 0 for g in graphs)
all_empty = all(
g._graph.number_of_nodes(ntype_id) == 0 for g in graphs
)
frames = [
g._node_frames[ntype_id] for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty]
g._node_frames[ntype_id]
for g in graphs
if g._graph.number_of_nodes(ntype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, ndata, 'nodes["{}"].data'.format(ntype))
ret_feat = _batch_feat_dicts(
frames, ndata, 'nodes["{}"].data'.format(ntype)
)
retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature
if edata is not None:
for etype_id, etype in zip(relation_ids, relations):
all_empty = all(g._graph.number_of_edges(etype_id) == 0 for g in graphs)
all_empty = all(
g._graph.number_of_edges(etype_id) == 0 for g in graphs
)
frames = [
g._edge_frames[etype_id] for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty]
g._edge_frames[etype_id]
for g in graphs
if g._graph.number_of_edges(etype_id) > 0 or all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
ret_feat = _batch_feat_dicts(frames, edata, 'edges[{}].data'.format(etype))
ret_feat = _batch_feat_dicts(
frames, edata, "edges[{}].data".format(etype)
)
retg.edges[etype].data.update(ret_feat)
return retg
def _batch_feat_dicts(frames, keys, feat_dict_name):
"""Internal function to batch feature dictionaries.
......@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
else:
utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
# concat features
ret_feat = {k : F.cat([fd[k] for fd in frames], 0) for k in keys}
ret_feat = {k: F.cat([fd[k] for fd in frames], 0) for k in keys}
return ret_feat
def unbatch(g, node_split=None, edge_split=None):
"""Revert the batch operation by split the given graph into a list of small ones.
......@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
num_split = None
# Parse node_split
if node_split is None:
node_split = {ntype : g.batch_num_nodes(ntype) for ntype in g.ntypes}
node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
elif not isinstance(node_split, Mapping):
if len(g.ntypes) != 1:
raise DGLError('Must provide a dictionary for argument node_split when'
' there are multiple node types.')
node_split = {g.ntypes[0] : node_split}
raise DGLError(
"Must provide a dictionary for argument node_split when"
" there are multiple node types."
)
node_split = {g.ntypes[0]: node_split}
if node_split.keys() != set(g.ntypes):
raise DGLError('Must specify node_split for each node type.')
raise DGLError("Must specify node_split for each node type.")
for split in node_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All node_split and edge_split must specify the same number'
' of split sizes.')
raise DGLError(
"All node_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split)
# Parse edge_split
if edge_split is None:
edge_split = {etype : g.batch_num_edges(etype) for etype in g.canonical_etypes}
edge_split = {
etype: g.batch_num_edges(etype) for etype in g.canonical_etypes
}
elif not isinstance(edge_split, Mapping):
if len(g.etypes) != 1:
raise DGLError('Must provide a dictionary for argument edge_split when'
' there are multiple edge types.')
edge_split = {g.canonical_etypes[0] : edge_split}
raise DGLError(
"Must provide a dictionary for argument edge_split when"
" there are multiple edge types."
)
edge_split = {g.canonical_etypes[0]: edge_split}
if edge_split.keys() != set(g.canonical_etypes):
raise DGLError('Must specify edge_split for each canonical edge type.')
raise DGLError("Must specify edge_split for each canonical edge type.")
for split in edge_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All edge_split and edge_split must specify the same number'
' of split sizes.')
raise DGLError(
"All edge_split and edge_split must specify the same number"
" of split sizes."
)
num_split = len(split)
node_split = {k : F.asnumpy(split).tolist() for k, split in node_split.items()}
edge_split = {k : F.asnumpy(split).tolist() for k, split in edge_split.items()}
node_split = {
k: F.asnumpy(split).tolist() for k, split in node_split.items()
}
edge_split = {
k: F.asnumpy(split).tolist() for k, split in edge_split.items()
}
# Split edges for each relation
edge_dict_per = [{} for i in range(num_split)]
for rel in g.canonical_etypes:
srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0
u, v = g.edges(order='eid', etype=rel)
u, v = g.edges(order="eid", etype=rel)
us = F.split(u, edge_split[rel], 0)
vs = F.split(v, edge_split[rel], 0)
for i, (subu, subv) in enumerate(zip(us, vs)):
edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
srcnid_off += node_split[srctype][i]
dstnid_off += node_split[dsttype][i]
num_nodes_dict_per = [{k : split[i] for k, split in node_split.items()}
for i in range(num_split)]
num_nodes_dict_per = [
{k: split[i] for k, split in node_split.items()}
for i in range(num_split)
]
# Create graphs
gs = [convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)]
gs = [
convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)
]
# Unbatch node features
for ntype in g.ntypes:
......@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
return gs
def slice_batch(g, gid, store_ids=False):
"""Get a particular graph from a batch of graphs.
......@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0:
start_nid.append(0)
else:
start_nid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0)))
start_nid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))
)
start_eid = []
num_edges = []
......@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
if gid == 0:
start_eid.append(0)
else:
start_eid.append(F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0)))
start_eid.append(
F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))
)
# Slice graph structure
gidx = slice_gidx(g._graph, utils.toindex(num_nodes), utils.toindex(start_nid),
utils.toindex(num_edges), utils.toindex(start_eid))
gidx = slice_gidx(
g._graph,
utils.toindex(num_nodes),
utils.toindex(start_nid),
utils.toindex(num_edges),
utils.toindex(start_eid),
)
retg = DGLGraph(gidx, g.ntypes, g.etypes)
# Slice node features
for ntid, ntype in enumerate(g.ntypes):
stnid = start_nid[ntid]
for key, feat in g.nodes[ntype].data.items():
subfeats = F.slice_axis(feat, 0, stnid, stnid+num_nodes[ntid])
subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])
retg.nodes[ntype].data[key] = subfeats
if store_ids:
retg.nodes[ntype].data[NID] = F.arange(stnid, stnid+num_nodes[ntid],
retg.idtype, retg.device)
retg.nodes[ntype].data[NID] = F.arange(
stnid, stnid + num_nodes[ntid], retg.idtype, retg.device
)
# Slice edge features
for etid, etype in enumerate(g.canonical_etypes):
steid = start_eid[etid]
for key, feat in g.edges[etype].data.items():
subfeats = F.slice_axis(feat, 0, steid, steid+num_edges[etid])
subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])
retg.edges[etype].data[key] = subfeats
if store_ids:
retg.edges[etype].data[EID] = F.arange(steid, steid+num_edges[etid],
retg.idtype, retg.device)
retg.edges[etype].data[EID] = F.arange(
steid, steid + num_edges[etid], retg.idtype, retg.device
)
return retg
"""Module for converting graph from/to other object."""
from collections import defaultdict
from collections.abc import Mapping
from scipy.sparse import spmatrix
import numpy as np
import networkx as nx
import numpy as np
from scipy.sparse import spmatrix
from . import backend as F
from . import heterograph_index
from .heterograph import DGLGraph, combine_frames, DGLBlock
from . import graph_index
from . import utils
from .base import NTYPE, ETYPE, NID, EID, DGLError
from . import backend as F, graph_index, heterograph_index, utils
from .base import DGLError, EID, ETYPE, NID, NTYPE
from .heterograph import combine_frames, DGLBlock, DGLGraph
__all__ = [
'graph',
'hetero_from_shared_memory',
'heterograph',
'create_block',
'block_to_graph',
'to_heterogeneous',
'to_homogeneous',
'from_scipy',
'bipartite_from_scipy',
'from_networkx',
'bipartite_from_networkx',
'to_networkx',
'from_cugraph',
'to_cugraph'
"graph",
"hetero_from_shared_memory",
"heterograph",
"create_block",
"block_to_graph",
"to_heterogeneous",
"to_homogeneous",
"from_scipy",
"bipartite_from_scipy",
"from_networkx",
"bipartite_from_networkx",
"to_networkx",
"from_cugraph",
"to_cugraph",
]
def graph(data,
*,
num_nodes=None,
idtype=None,
device=None,
row_sorted=False,
col_sorted=False):
def graph(
data,
*,
num_nodes=None,
idtype=None,
device=None,
row_sorted=False,
col_sorted=False,
):
"""Create a graph and return.
Parameters
......@@ -147,25 +148,41 @@ def graph(data,
from_networkx
"""
if isinstance(data, spmatrix):
raise DGLError("dgl.graph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead.")
raise DGLError(
"dgl.graph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead."
)
if isinstance(data, nx.Graph):
raise DGLError("dgl.graph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead.")
raise DGLError(
"dgl.graph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead."
)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)
if num_nodes is not None: # override the number of nodes
if num_nodes < max(urange, vrange):
raise DGLError('The num_nodes argument must be larger than the max ID in the data,'
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1))
raise DGLError(
"The num_nodes argument must be larger than the max ID in the data,"
" but got {} and {}.".format(num_nodes, max(urange, vrange) - 1)
)
urange, vrange = num_nodes, num_nodes
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange,
row_sorted=row_sorted, col_sorted=col_sorted)
g = create_from_edges(
sparse_fmt,
arrays,
"_N",
"_E",
"_N",
urange,
vrange,
row_sorted=row_sorted,
col_sorted=col_sorted,
)
return g.to(device)
def hetero_from_shared_memory(name):
"""Create a heterograph from shared memory with the given name.
......@@ -181,13 +198,13 @@ def hetero_from_shared_memory(name):
-------
HeteroGraph (in shared memory)
"""
g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(name)
g, ntypes, etypes = heterograph_index.create_heterograph_from_shared_memory(
name
)
return DGLGraph(g, ntypes, etypes)
def heterograph(data_dict,
num_nodes_dict=None,
idtype=None,
device=None):
def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):
"""Create a heterogeneous graph and return.
Parameters
......@@ -300,47 +317,77 @@ def heterograph(data_dict,
num_nodes_dict = defaultdict(int)
for (sty, ety, dty), data in data_dict.items():
if isinstance(data, spmatrix):
raise DGLError("dgl.heterograph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead.")
raise DGLError(
"dgl.heterograph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead."
)
if isinstance(data, nx.Graph):
raise DGLError("dgl.heterograph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead.")
is_bipartite = (sty != dty)
raise DGLError(
"dgl.heterograph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead."
)
is_bipartite = sty != dty
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=is_bipartite)
data, idtype, bipartite=is_bipartite
)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
num_nodes_dict[sty] = max(num_nodes_dict[sty], urange)
num_nodes_dict[dty] = max(num_nodes_dict[dty], vrange)
else: # sanity check
if num_nodes_dict[sty] < urange:
raise DGLError('The given number of nodes of node type {} must be larger than'
' the max ID in the data, but got {} and {}.'.format(
sty, num_nodes_dict[sty], urange - 1))
raise DGLError(
"The given number of nodes of node type {} must be larger than"
" the max ID in the data, but got {} and {}.".format(
sty, num_nodes_dict[sty], urange - 1
)
)
if num_nodes_dict[dty] < vrange:
raise DGLError('The given number of nodes of node type {} must be larger than'
' the max ID in the data, but got {} and {}.'.format(
dty, num_nodes_dict[dty], vrange - 1))
raise DGLError(
"The given number of nodes of node type {} must be larger than"
" the max ID in the data, but got {} and {}.".format(
dty, num_nodes_dict[dty], vrange - 1
)
)
# Create the graph
metagraph, ntypes, etypes, relations = heterograph_index.create_metagraph_index(
num_nodes_dict.keys(), node_tensor_dict.keys())
num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64")
(
metagraph,
ntypes,
etypes,
relations,
) = heterograph_index.create_metagraph_index(
num_nodes_dict.keys(), node_tensor_dict.keys()
)
num_nodes_per_type = utils.toindex(
[num_nodes_dict[ntype] for ntype in ntypes], "int64"
)
rel_graphs = []
for srctype, etype, dsttype in relations:
sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(sparse_fmt, arrays, srctype, etype, dsttype,
num_nodes_dict[srctype], num_nodes_dict[dsttype])
g = create_from_edges(
sparse_fmt,
arrays,
srctype,
etype,
dsttype,
num_nodes_dict[srctype],
num_nodes_dict[dsttype],
)
rel_graphs.append(g)
# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type
)
retg = DGLGraph(hgidx, ntypes, etypes)
return retg.to(device)
def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None):
def create_block(
data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None
):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
Parameters
......@@ -464,21 +511,25 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
"""
need_infer = num_src_nodes is None and num_dst_nodes is None
if not isinstance(data_dict, Mapping):
data_dict = {('_N', '_E', '_N'): data_dict}
data_dict = {("_N", "_E", "_N"): data_dict}
if not need_infer:
assert isinstance(num_src_nodes, int), \
"num_src_nodes must be a pair of integers if data_dict is not a dict"
assert isinstance(num_dst_nodes, int), \
"num_dst_nodes must be a pair of integers if data_dict is not a dict"
num_src_nodes = {'_N': num_src_nodes}
num_dst_nodes = {'_N': num_dst_nodes}
assert isinstance(
num_src_nodes, int
), "num_src_nodes must be a pair of integers if data_dict is not a dict"
assert isinstance(
num_dst_nodes, int
), "num_dst_nodes must be a pair of integers if data_dict is not a dict"
num_src_nodes = {"_N": num_src_nodes}
num_dst_nodes = {"_N": num_dst_nodes}
else:
if not need_infer:
assert isinstance(num_src_nodes, Mapping), \
"num_src_nodes must be a dict if data_dict is a dict"
assert isinstance(num_dst_nodes, Mapping), \
"num_dst_nodes must be a dict if data_dict is a dict"
assert isinstance(
num_src_nodes, Mapping
), "num_src_nodes must be a dict if data_dict is a dict"
assert isinstance(
num_dst_nodes, Mapping
), "num_dst_nodes must be a dict if data_dict is a dict"
if need_infer:
num_src_nodes = defaultdict(int)
......@@ -488,20 +539,27 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
node_tensor_dict = {}
for (sty, ety, dty), data in data_dict.items():
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=True)
data, idtype, bipartite=True
)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
num_src_nodes[sty] = max(num_src_nodes[sty], urange)
num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)
else: # sanity check
if num_src_nodes[sty] < urange:
raise DGLError('The given number of nodes of source node type {} must be larger'
' than the max ID in the data, but got {} and {}.'.format(
sty, num_src_nodes[sty], urange - 1))
raise DGLError(
"The given number of nodes of source node type {} must be larger"
" than the max ID in the data, but got {} and {}.".format(
sty, num_src_nodes[sty], urange - 1
)
)
if num_dst_nodes[dty] < vrange:
raise DGLError('The given number of nodes of destination node type {} must be'
' larger than the max ID in the data, but got {} and {}.'.format(
dty, num_dst_nodes[dty], vrange - 1))
raise DGLError(
"The given number of nodes of destination node type {} must be"
" larger than the max ID in the data, but got {} and {}.".format(
dty, num_dst_nodes[dty], vrange - 1
)
)
# Create the graph
# Sort the ntypes and relation tuples to have a deterministic order for the same set
......@@ -511,10 +569,14 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
relations = list(sorted(node_tensor_dict.keys()))
num_nodes_per_type = utils.toindex(
[num_src_nodes[ntype] for ntype in srctypes] +
[num_dst_nodes[ntype] for ntype in dsttypes], "int64")
[num_src_nodes[ntype] for ntype in srctypes]
+ [num_dst_nodes[ntype] for ntype in dsttypes],
"int64",
)
srctype_dict = {ntype: i for i, ntype in enumerate(srctypes)}
dsttype_dict = {ntype: i + len(srctypes) for i, ntype in enumerate(dsttypes)}
dsttype_dict = {
ntype: i + len(srctypes) for i, ntype in enumerate(dsttypes)
}
meta_edges_src = []
meta_edges_dst = []
......@@ -525,20 +587,30 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
meta_edges_dst.append(dsttype_dict[dsttype])
etypes.append(etype)
sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(sparse_fmt, arrays, 'SRC/' + srctype, etype, 'DST/' + dsttype,
num_src_nodes[srctype], num_dst_nodes[dsttype])
g = create_from_edges(
sparse_fmt,
arrays,
"SRC/" + srctype,
etype,
"DST/" + dsttype,
num_src_nodes[srctype],
num_dst_nodes[dsttype],
)
rel_graphs.append(g)
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph = graph_index.from_coo(
len(srctypes) + len(dsttypes), meta_edges_src, meta_edges_dst, True)
len(srctypes) + len(dsttypes), meta_edges_src, meta_edges_dst, True
)
# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type
)
retg = DGLBlock(hgidx, (srctypes, dsttypes), etypes)
return retg.to(device)
def block_to_graph(block):
"""Convert a message flow graph (MFG) as a :class:`DGLBlock` object to a :class:`DGLGraph`.
......@@ -568,22 +640,26 @@ def block_to_graph(block):
num_edges={('A_src', 'AB', 'B_dst'): 3, ('B_src', 'BA', 'A_dst'): 2},
metagraph=[('A_src', 'B_dst', 'AB'), ('B_src', 'A_dst', 'BA')])
"""
new_types = [ntype + '_src' for ntype in block.srctypes] + \
[ntype + '_dst' for ntype in block.dsttypes]
new_types = [ntype + "_src" for ntype in block.srctypes] + [
ntype + "_dst" for ntype in block.dsttypes
]
retg = DGLGraph(block._graph, new_types, block.etypes)
for srctype in block.srctypes:
retg.nodes[srctype + '_src'].data.update(block.srcnodes[srctype].data)
retg.nodes[srctype + "_src"].data.update(block.srcnodes[srctype].data)
for dsttype in block.dsttypes:
retg.nodes[dsttype + '_dst'].data.update(block.dstnodes[dsttype].data)
retg.nodes[dsttype + "_dst"].data.update(block.dstnodes[dsttype].data)
for srctype, etype, dsttype in block.canonical_etypes:
retg.edges[srctype + '_src', etype, dsttype + '_dst'].data.update(
block.edges[srctype, etype, dsttype].data)
retg.edges[srctype + "_src", etype, dsttype + "_dst"].data.update(
block.edges[srctype, etype, dsttype].data
)
return retg
def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
etype_field=ETYPE, metagraph=None):
def to_heterogeneous(
G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph=None
):
"""Convert a homogeneous graph to a heterogeneous graph and return.
The input graph should have only one type of nodes and edges. Each node and edge
......@@ -691,10 +767,16 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
--------
to_homogeneous
"""
if (hasattr(G, 'ntypes') and len(G.ntypes) > 1
or hasattr(G, 'etypes') and len(G.etypes) > 1):
raise DGLError('The input graph should be homogeneous and have only one '
' type of nodes and edges.')
if (
hasattr(G, "ntypes")
and len(G.ntypes) > 1
or hasattr(G, "etypes")
and len(G.etypes) > 1
):
raise DGLError(
"The input graph should be homogeneous and have only one "
" type of nodes and edges."
)
num_ntypes = len(ntypes)
idtype = G.idtype
......@@ -706,15 +788,15 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
# relabel nodes to per-type local IDs
ntype_count = np.bincount(ntype_ids, minlength=num_ntypes)
ntype_offset = np.insert(np.cumsum(ntype_count), 0, 0)
ntype_ids_sortidx = np.argsort(ntype_ids, kind='stable')
ntype_ids_sortidx = np.argsort(ntype_ids, kind="stable")
ntype_local_ids = np.zeros_like(ntype_ids)
node_groups = []
for i in range(num_ntypes):
node_group = ntype_ids_sortidx[ntype_offset[i]:ntype_offset[i+1]]
node_group = ntype_ids_sortidx[ntype_offset[i] : ntype_offset[i + 1]]
node_groups.append(node_group)
ntype_local_ids[node_group] = np.arange(ntype_count[i])
src, dst = G.all_edges(order='eid')
src, dst = G.all_edges(order="eid")
src = F.asnumpy(src)
dst = F.asnumpy(dst)
src_local = ntype_local_ids[src]
......@@ -729,21 +811,28 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
# above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the
# canonical edge type j. We can then group the edges of the same type together.
if metagraph is None:
canonical_etids, _, etype_remapped = \
utils.make_invmap(list(tuple(_) for _ in edge_ctids), False)
etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None])
canonical_etids, _, etype_remapped = utils.make_invmap(
list(tuple(_) for _ in edge_ctids), False
)
etype_mask = (
etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None]
)
else:
ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}
etypes_invmap = {et: i for i, et in enumerate(etypes)}
canonical_etids = []
for i, (srctype, dsttype, etype) in enumerate(metagraph.edges(keys=True)):
for i, (srctype, dsttype, etype) in enumerate(
metagraph.edges(keys=True)
):
srctype_id = ntypes_invmap[srctype]
etype_id = etypes_invmap[etype]
dsttype_id = ntypes_invmap[dsttype]
canonical_etids.append((srctype_id, etype_id, dsttype_id))
canonical_etids = np.asarray(canonical_etids)
etype_mask = (edge_ctids[None, :] == canonical_etids[:, None]).all(2)
edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))]
edge_groups = [
etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))
]
data_dict = dict()
canonical_etypes = []
......@@ -751,13 +840,12 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
src_of_etype = src_local[edge_groups[i]]
dst_of_etype = dst_local[edge_groups[i]]
canonical_etypes.append((ntypes[stid], etypes[etid], ntypes[dtid]))
data_dict[canonical_etypes[-1]] = \
(src_of_etype, dst_of_etype)
hg = heterograph(data_dict,
dict(zip(ntypes, ntype_count)),
idtype=idtype, device=device)
data_dict[canonical_etypes[-1]] = (src_of_etype, dst_of_etype)
hg = heterograph(
data_dict, dict(zip(ntypes, ntype_count)), idtype=idtype, device=device
)
ntype2ngrp = {ntype : node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
ntype2ngrp = {ntype: node_groups[ntid] for ntid, ntype in enumerate(ntypes)}
# features
for key, data in G.ndata.items():
......@@ -772,19 +860,26 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
continue
for etid in range(len(hg.canonical_etypes)):
rows = F.copy_to(F.tensor(edge_groups[etid]), F.context(data))
hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][key] = \
F.gather_row(data, rows)
hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][
key
] = F.gather_row(data, rows)
# Record the original IDs of the nodes/edges
for ntid, ntype in enumerate(hg.ntypes):
hg._node_frames[ntid][NID] = F.copy_to(F.tensor(ntype2ngrp[ntype]), device)
hg._node_frames[ntid][NID] = F.copy_to(
F.tensor(ntype2ngrp[ntype]), device
)
for etid in range(len(hg.canonical_etypes)):
hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][EID] = \
F.copy_to(F.tensor(edge_groups[etid]), device)
hg._edge_frames[hg.get_etype_id(canonical_etypes[etid])][
EID
] = F.copy_to(F.tensor(edge_groups[etid]), device)
return hg
def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=False):
def to_homogeneous(
G, ndata=None, edata=None, store_type=True, return_count=False
):
"""Convert a heterogeneous graph to a homogeneous graph and return.
By default, the function stores the node and edge types of the input graph as
......@@ -902,7 +997,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
for etype_id, etype in enumerate(G.canonical_etypes):
srctype, _, dsttype = etype
src, dst = G.all_edges(etype=etype, order='eid')
src, dst = G.all_edges(etype=etype, order="eid")
num_edges = len(src)
srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
......@@ -913,16 +1008,24 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
etype_count.append(num_edges)
eids.append(F.arange(0, num_edges, G.idtype, G.device))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
idtype=G.idtype, device=G.device)
retg = graph(
(F.cat(srcs, 0), F.cat(dsts, 0)),
num_nodes=total_num_nodes,
idtype=G.idtype,
device=G.device,
)
# copy features
if ndata is None:
ndata = []
if edata is None:
edata = []
comb_nf = combine_frames(G._node_frames, range(len(G.ntypes)), col_names=ndata)
comb_ef = combine_frames(G._edge_frames, range(len(G.etypes)), col_names=edata)
comb_nf = combine_frames(
G._node_frames, range(len(G.ntypes)), col_names=ndata
)
comb_ef = combine_frames(
G._edge_frames, range(len(G.etypes)), col_names=edata
)
if comb_nf is not None:
retg.ndata.update(comb_nf)
if comb_ef is not None:
......@@ -939,10 +1042,8 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
else:
return retg
def from_scipy(sp_mat,
eweight_name=None,
idtype=None,
device=None):
def from_scipy(sp_mat, eweight_name=None, idtype=None, device=None):
"""Create a graph from a SciPy sparse matrix and return.
Parameters
......@@ -1019,20 +1120,23 @@ def from_scipy(sp_mat,
num_rows = sp_mat.shape[0]
num_cols = sp_mat.shape[1]
if num_rows != num_cols:
raise DGLError('Expect the number of rows to be the same as the number of columns for '
'sp_mat, got {:d} and {:d}.'.format(num_rows, num_cols))
raise DGLError(
"Expect the number of rows to be the same as the number of columns for "
"sp_mat, got {:d} and {:d}.".format(num_rows, num_cols)
)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype)
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
sp_mat, idtype
)
g = create_from_edges(sparse_fmt, arrays, "_N", "_E", "_N", urange, vrange)
if eweight_name is not None:
g.edata[eweight_name] = F.tensor(sp_mat.data)
return g.to(device)
def bipartite_from_scipy(sp_mat,
utype, etype, vtype,
eweight_name=None,
idtype=None,
device=None):
def bipartite_from_scipy(
sp_mat, utype, etype, vtype, eweight_name=None, idtype=None, device=None
):
"""Create a uni-directional bipartite graph from a SciPy sparse matrix and return.
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
......@@ -1116,18 +1220,25 @@ def bipartite_from_scipy(sp_mat,
heterograph
bipartite_from_networkx
"""
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(sp_mat, idtype, bipartite=True)
g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
sp_mat, idtype, bipartite=True
)
g = create_from_edges(
sparse_fmt, arrays, utype, etype, vtype, urange, vrange
)
if eweight_name is not None:
g.edata[eweight_name] = F.tensor(sp_mat.data)
return g.to(device)
def from_networkx(nx_graph,
node_attrs=None,
edge_attrs=None,
edge_id_attr_name=None,
idtype=None,
device=None):
def from_networkx(
nx_graph,
node_attrs=None,
edge_attrs=None,
edge_id_attr_name=None,
idtype=None,
device=None,
):
"""Create a graph from a NetworkX graph and return.
.. note::
......@@ -1221,27 +1332,38 @@ def from_networkx(nx_graph,
from_scipy
"""
# Sanity check
if edge_id_attr_name is not None and \
edge_id_attr_name not in next(iter(nx_graph.edges(data=True)))[-1]:
raise DGLError('Failed to find the pre-specified edge IDs in the edge features of '
'the NetworkX graph with name {}'.format(edge_id_attr_name))
if not nx_graph.is_directed() and not (edge_id_attr_name is None and edge_attrs is None):
raise DGLError('Expect edge_id_attr_name and edge_attrs to be None when nx_graph is '
'undirected, got {} and {}'.format(edge_id_attr_name, edge_attrs))
if (
edge_id_attr_name is not None
and edge_id_attr_name not in next(iter(nx_graph.edges(data=True)))[-1]
):
raise DGLError(
"Failed to find the pre-specified edge IDs in the edge features of "
"the NetworkX graph with name {}".format(edge_id_attr_name)
)
if not nx_graph.is_directed() and not (
edge_id_attr_name is None and edge_attrs is None
):
raise DGLError(
"Expect edge_id_attr_name and edge_attrs to be None when nx_graph is "
"undirected, got {} and {}".format(edge_id_attr_name, edge_attrs)
)
# Relabel nodes using consecutive integers starting from 0
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering="sorted")
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
nx_graph, idtype, edge_id_attr_name=edge_id_attr_name)
nx_graph, idtype, edge_id_attr_name=edge_id_attr_name
)
g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange)
g = create_from_edges(sparse_fmt, arrays, "_N", "_E", "_N", urange, vrange)
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
has_edge_id = (
nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
)
# handle features
# copy attributes
......@@ -1250,6 +1372,7 @@ def from_networkx(nx_graph,
return F.cat([F.unsqueeze(x, 0) for x in lst], dim=0)
else:
return F.tensor(lst)
if node_attrs is not None:
# mapping from feature name to a list of tensors to be concatenated
attr_dict = defaultdict(list)
......@@ -1269,9 +1392,11 @@ def from_networkx(nx_graph,
num_edges = g.number_of_edges()
for _, _, attrs in nx_graph.edges(data=True):
if attrs[edge_id_attr_name] >= num_edges:
raise DGLError('Expect the pre-specified edge ids to be'
' smaller than the number of edges --'
' {}, got {}.'.format(num_edges, attrs['id']))
raise DGLError(
"Expect the pre-specified edge ids to be"
" smaller than the number of edges --"
" {}, got {}.".format(num_edges, attrs["id"])
)
for key in edge_attrs:
attr_dict[key][attrs[edge_id_attr_name]] = attrs[key]
else:
......@@ -1283,17 +1408,26 @@ def from_networkx(nx_graph,
for attr in edge_attrs:
for val in attr_dict[attr]:
if val is None:
raise DGLError('Not all edges have attribute {}.'.format(attr))
raise DGLError(
"Not all edges have attribute {}.".format(attr)
)
g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)
return g.to(device)
def bipartite_from_networkx(nx_graph,
utype, etype, vtype,
u_attrs=None, e_attrs=None, v_attrs=None,
edge_id_attr_name=None,
idtype=None,
device=None):
def bipartite_from_networkx(
nx_graph,
utype,
etype,
vtype,
u_attrs=None,
e_attrs=None,
v_attrs=None,
edge_id_attr_name=None,
idtype=None,
device=None,
):
"""Create a unidirectional bipartite graph from a NetworkX graph and return.
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
......@@ -1403,42 +1537,58 @@ def bipartite_from_networkx(nx_graph,
bipartite_from_scipy
"""
if not nx_graph.is_directed():
raise DGLError('Expect nx_graph to be a directed NetworkX graph.')
if edge_id_attr_name is not None and \
not edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]:
raise DGLError('Failed to find the pre-specified edge IDs in the edge features '
'of the NetworkX graph with name {}'.format(edge_id_attr_name))
raise DGLError("Expect nx_graph to be a directed NetworkX graph.")
if (
edge_id_attr_name is not None
and not edge_id_attr_name in next(iter(nx_graph.edges(data=True)))[-1]
):
raise DGLError(
"Failed to find the pre-specified edge IDs in the edge features "
"of the NetworkX graph with name {}".format(edge_id_attr_name)
)
# Get the source and destination node sets
top_nodes = set()
bottom_nodes = set()
for n, ndata in nx_graph.nodes(data=True):
if 'bipartite' not in ndata:
raise DGLError('Expect the node {} to have attribute bipartite'.format(n))
if ndata['bipartite'] == 0:
if "bipartite" not in ndata:
raise DGLError(
"Expect the node {} to have attribute bipartite".format(n)
)
if ndata["bipartite"] == 0:
top_nodes.add(n)
elif ndata['bipartite'] == 1:
elif ndata["bipartite"] == 1:
bottom_nodes.add(n)
else:
raise ValueError('Expect the bipartite attribute of the node {} to be 0 or 1, '
'got {}'.format(n, ndata['bipartite']))
raise ValueError(
"Expect the bipartite attribute of the node {} to be 0 or 1, "
"got {}".format(n, ndata["bipartite"])
)
# Separately relabel the source and destination nodes.
top_nodes = sorted(top_nodes)
bottom_nodes = sorted(bottom_nodes)
top_map = {n : i for i, n in enumerate(top_nodes)}
bottom_map = {n : i for i, n in enumerate(bottom_nodes)}
top_map = {n: i for i, n in enumerate(top_nodes)}
bottom_map = {n: i for i, n in enumerate(bottom_nodes)}
# Get the node tensors and the number of nodes
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
nx_graph, idtype, bipartite=True,
nx_graph,
idtype,
bipartite=True,
edge_id_attr_name=edge_id_attr_name,
top_map=top_map, bottom_map=bottom_map)
top_map=top_map,
bottom_map=bottom_map,
)
g = create_from_edges(sparse_fmt, arrays, utype, etype, vtype, urange, vrange)
g = create_from_edges(
sparse_fmt, arrays, utype, etype, vtype, urange, vrange
)
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
has_edge_id = (
nx_graph.number_of_edges() > 0 and edge_id_attr_name is not None
)
# handle features
# copy attributes
......@@ -1485,11 +1635,14 @@ def bipartite_from_networkx(nx_graph,
for attr in e_attrs:
for val in attr_dict[attr]:
if val is None:
raise DGLError('Not all edges have attribute {}.'.format(attr))
raise DGLError(
"Not all edges have attribute {}.".format(attr)
)
g.edata[attr] = F.copy_to(_batcher(attr_dict[attr]), g.device)
return g.to(device)
def to_networkx(g, node_attrs=None, edge_attrs=None):
"""Convert a homogeneous graph to a NetworkX graph and return.
......@@ -1537,9 +1690,11 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})])
"""
if g.device != F.cpu():
raise DGLError('Cannot convert a CUDA graph to networkx. Call g.cpu() first.')
raise DGLError(
"Cannot convert a CUDA graph to networkx. Call g.cpu() first."
)
if not g.is_homogeneous:
raise DGLError('dgl.to_networkx only supports homogeneous graphs.')
raise DGLError("dgl.to_networkx only supports homogeneous graphs.")
src, dst = g.edges()
src = F.asnumpy(src)
dst = F.asnumpy(dst)
......@@ -1552,16 +1707,22 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True):
feat_dict = g._get_n_repr(0, nid)
attr.update({key: F.squeeze(feat_dict[key], 0) for key in node_attrs})
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in node_attrs}
)
if edge_attrs is not None:
for _, _, attr in nx_graph.edges(data=True):
eid = attr['id']
eid = attr["id"]
feat_dict = g._get_e_repr(0, eid)
attr.update({key: F.squeeze(feat_dict[key], 0) for key in edge_attrs})
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in edge_attrs}
)
return nx_graph
DGLGraph.to_networkx = to_networkx
def to_cugraph(g):
"""Convert a DGL graph to a :class:`cugraph.Graph` and return.
......@@ -1595,30 +1756,36 @@ def to_cugraph(g):
1 1 1
"""
if g.device.type != 'cuda':
raise DGLError(f"Cannot convert a {g.device.type} graph to cugraph." +
"Call g.to('cuda') first.")
if g.device.type != "cuda":
raise DGLError(
f"Cannot convert a {g.device.type} graph to cugraph."
+ "Call g.to('cuda') first."
)
if not g.is_homogeneous:
raise DGLError("dgl.to_cugraph only supports homogeneous graphs.")
try:
import cugraph
import cudf
import cugraph
except ModuleNotFoundError:
raise ModuleNotFoundError("to_cugraph requires cugraph which could not be imported")
raise ModuleNotFoundError(
"to_cugraph requires cugraph which could not be imported"
)
edgelist = g.edges()
src_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[0]))
dst_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(edgelist[1]))
cudf_data = cudf.DataFrame({'source':src_ser, 'destination':dst_ser})
cudf_data = cudf.DataFrame({"source": src_ser, "destination": dst_ser})
g_cugraph = cugraph.Graph(directed=True)
g_cugraph.from_cudf_edgelist(cudf_data,
source='source',
destination='destination')
g_cugraph.from_cudf_edgelist(
cudf_data, source="source", destination="destination"
)
return g_cugraph
DGLGraph.to_cugraph = to_cugraph
def from_cugraph(cugraph_graph):
"""Create a graph from a :class:`cugraph.Graph` object.
......@@ -1660,21 +1827,29 @@ def from_cugraph(cugraph_graph):
cugraph_graph = cugraph_graph.to_directed()
edges = cugraph_graph.edges()
src_t = F.zerocopy_from_dlpack(edges['src'].to_dlpack())
dst_t = F.zerocopy_from_dlpack(edges['dst'].to_dlpack())
g = graph((src_t,dst_t))
src_t = F.zerocopy_from_dlpack(edges["src"].to_dlpack())
dst_t = F.zerocopy_from_dlpack(edges["dst"].to_dlpack())
g = graph((src_t, dst_t))
return g
############################################################
# Internal APIs
############################################################
def create_from_edges(sparse_fmt, arrays,
utype, etype, vtype,
urange, vrange,
row_sorted=False,
col_sorted=False):
def create_from_edges(
sparse_fmt,
arrays,
utype,
etype,
vtype,
urange,
vrange,
row_sorted=False,
col_sorted=False,
):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
......@@ -1713,16 +1888,30 @@ def create_from_edges(sparse_fmt, arrays,
else:
num_ntypes = 2
if sparse_fmt == 'coo':
if sparse_fmt == "coo":
u, v = arrays
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted)
else: # 'csr' or 'csc'
num_ntypes,
urange,
vrange,
u,
v,
["coo", "csr", "csc"],
row_sorted,
col_sorted,
)
else: # 'csr' or 'csc'
indptr, indices, eids = arrays
hgidx = heterograph_index.create_unitgraph_from_csr(
num_ntypes, urange, vrange, indptr, indices, eids, ['coo', 'csr', 'csc'],
sparse_fmt == 'csc')
num_ntypes,
urange,
vrange,
indptr,
indices,
eids,
["coo", "csr", "csc"],
sparse_fmt == "csc",
)
if utype == vtype:
return DGLGraph(hgidx, [utype], [etype])
else:
......
......@@ -2,10 +2,8 @@
# pylint: disable=not-callable
import numpy as np
from . import backend as F
from . import function as fn
from . import ops
from .base import ALL, EID, NID, DGLError, dgl_warning, is_all
from . import backend as F, function as fn, ops
from .base import ALL, dgl_warning, DGLError, EID, is_all, NID
from .frame import Frame
from .udf import EdgeBatch, NodeBatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment