Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a566b60b
Unverified
Commit
a566b60b
authored
Feb 21, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 21, 2023
Browse files
auto-fix (#5331)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
d1827488
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
517 additions
and
282 deletions
+517
-282
python/dgl/_ffi/function.py
python/dgl/_ffi/function.py
+3
-3
python/dgl/_ffi/ndarray.py
python/dgl/_ffi/ndarray.py
+4
-4
python/dgl/_ffi/object.py
python/dgl/_ffi/object.py
+4
-7
python/dgl/batch.py
python/dgl/batch.py
+103
-52
python/dgl/convert.py
python/dgl/convert.py
+401
-212
python/dgl/core.py
python/dgl/core.py
+2
-4
No files found.
python/dgl/_ffi/function.py
View file @
a566b60b
...
@@ -14,26 +14,26 @@ try:
...
@@ -14,26 +14,26 @@ try:
if
_FFI_MODE
==
"ctypes"
:
if
_FFI_MODE
==
"ctypes"
:
raise
ImportError
()
raise
ImportError
()
if
sys
.
version_info
>=
(
3
,
0
):
if
sys
.
version_info
>=
(
3
,
0
):
from
._cy3.core
import
FunctionBase
as
_FunctionBase
from
._cy3.core
import
(
from
._cy3.core
import
(
_set_class_function
,
_set_class_function
,
_set_class_module
,
_set_class_module
,
convert_to_dgl_func
,
convert_to_dgl_func
,
FunctionBase
as
_FunctionBase
,
)
)
else
:
else
:
from
._cy2.core
import
FunctionBase
as
_FunctionBase
from
._cy2.core
import
(
from
._cy2.core
import
(
_set_class_function
,
_set_class_function
,
_set_class_module
,
_set_class_module
,
convert_to_dgl_func
,
convert_to_dgl_func
,
FunctionBase
as
_FunctionBase
,
)
)
except
IMPORT_EXCEPT
:
except
IMPORT_EXCEPT
:
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from
._ctypes.function
import
FunctionBase
as
_FunctionBase
from
._ctypes.function
import
(
from
._ctypes.function
import
(
_set_class_function
,
_set_class_function
,
_set_class_module
,
_set_class_module
,
convert_to_dgl_func
,
convert_to_dgl_func
,
FunctionBase
as
_FunctionBase
,
)
)
FunctionHandle
=
ctypes
.
c_void_p
FunctionHandle
=
ctypes
.
c_void_p
...
...
python/dgl/_ffi/ndarray.py
View file @
a566b60b
...
@@ -9,12 +9,12 @@ import numpy as np
...
@@ -9,12 +9,12 @@ import numpy as np
from
.base
import
_FFI_MODE
,
_LIB
,
c_array
,
c_str
,
check_call
,
string_types
from
.base
import
_FFI_MODE
,
_LIB
,
c_array
,
c_str
,
check_call
,
string_types
from
.runtime_ctypes
import
(
from
.runtime_ctypes
import
(
dgl_shape_index_t
,
DGLArray
,
DGLArray
,
DGLArrayHandle
,
DGLArrayHandle
,
DGLContext
,
DGLContext
,
DGLDataType
,
DGLDataType
,
TypeCode
,
TypeCode
,
dgl_shape_index_t
,
)
)
IMPORT_EXCEPT
=
RuntimeError
if
_FFI_MODE
==
"cython"
else
ImportError
IMPORT_EXCEPT
=
RuntimeError
if
_FFI_MODE
==
"cython"
else
ImportError
...
@@ -24,29 +24,29 @@ try:
...
@@ -24,29 +24,29 @@ try:
if
_FFI_MODE
==
"ctypes"
:
if
_FFI_MODE
==
"ctypes"
:
raise
ImportError
()
raise
ImportError
()
if
sys
.
version_info
>=
(
3
,
0
):
if
sys
.
version_info
>=
(
3
,
0
):
from
._cy3.core
import
NDArrayBase
as
_NDArrayBase
from
._cy3.core
import
(
from
._cy3.core
import
(
_from_dlpack
,
_from_dlpack
,
_make_array
,
_make_array
,
_reg_extension
,
_reg_extension
,
_set_class_ndarray
,
_set_class_ndarray
,
NDArrayBase
as
_NDArrayBase
,
)
)
else
:
else
:
from
._cy2.core
import
NDArrayBase
as
_NDArrayBase
from
._cy2.core
import
(
from
._cy2.core
import
(
_from_dlpack
,
_from_dlpack
,
_make_array
,
_make_array
,
_reg_extension
,
_reg_extension
,
_set_class_ndarray
,
_set_class_ndarray
,
NDArrayBase
as
_NDArrayBase
,
)
)
except
IMPORT_EXCEPT
:
except
IMPORT_EXCEPT
:
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from
._ctypes.ndarray
import
NDArrayBase
as
_NDArrayBase
from
._ctypes.ndarray
import
(
from
._ctypes.ndarray
import
(
_from_dlpack
,
_from_dlpack
,
_make_array
,
_make_array
,
_reg_extension
,
_reg_extension
,
_set_class_ndarray
,
_set_class_ndarray
,
NDArrayBase
as
_NDArrayBase
,
)
)
...
...
python/dgl/_ffi/object.py
View file @
a566b60b
...
@@ -7,7 +7,7 @@ import sys
...
@@ -7,7 +7,7 @@ import sys
from
..
import
_api_internal
from
..
import
_api_internal
from
.base
import
_FFI_MODE
,
_LIB
,
c_str
,
check_call
,
py_str
from
.base
import
_FFI_MODE
,
_LIB
,
c_str
,
check_call
,
py_str
from
.object_generic
import
ObjectGeneric
,
convert_to_object
from
.object_generic
import
convert_to_object
,
ObjectGeneric
# pylint: disable=invalid-name
# pylint: disable=invalid-name
IMPORT_EXCEPT
=
RuntimeError
if
_FFI_MODE
==
"cython"
else
ImportError
IMPORT_EXCEPT
=
RuntimeError
if
_FFI_MODE
==
"cython"
else
ImportError
...
@@ -16,15 +16,12 @@ try:
...
@@ -16,15 +16,12 @@ try:
if
_FFI_MODE
==
"ctypes"
:
if
_FFI_MODE
==
"ctypes"
:
raise
ImportError
()
raise
ImportError
()
if
sys
.
version_info
>=
(
3
,
0
):
if
sys
.
version_info
>=
(
3
,
0
):
from
._cy3.core
import
ObjectBase
as
_ObjectBase
from
._cy3.core
import
_register_object
,
ObjectBase
as
_ObjectBase
from
._cy3.core
import
_register_object
else
:
else
:
from
._cy2.core
import
ObjectBase
as
_ObjectBase
from
._cy2.core
import
_register_object
,
ObjectBase
as
_ObjectBase
from
._cy2.core
import
_register_object
except
IMPORT_EXCEPT
:
except
IMPORT_EXCEPT
:
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from
._ctypes.object
import
ObjectBase
as
_ObjectBase
from
._ctypes.object
import
_register_object
,
ObjectBase
as
_ObjectBase
from
._ctypes.object
import
_register_object
def
_new_object
(
cls
):
def
_new_object
(
cls
):
...
...
python/dgl/batch.py
View file @
a566b60b
"""Utilities for batching/unbatching graphs."""
"""Utilities for batching/unbatching graphs."""
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
.
import
backend
as
F
from
.
import
backend
as
F
,
convert
,
utils
from
.base
import
ALL
,
is_all
,
DGLError
,
NID
,
EID
from
.base
import
ALL
,
DGLError
,
EID
,
is_all
,
NID
from
.heterograph_index
import
disjoint_union
,
slice_gidx
from
.heterograph
import
DGLGraph
from
.heterograph
import
DGLGraph
from
.
import
convert
from
.
heterograph_index
import
disjoint_union
,
slice_gidx
from
.
import
utils
__all__
=
[
"batch"
,
"unbatch"
,
"slice_batch"
]
__all__
=
[
'batch'
,
'unbatch'
,
'slice_batch'
]
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
):
def
batch
(
graphs
,
ndata
=
ALL
,
edata
=
ALL
):
r
"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
r
"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
...
@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
...
@@ -149,13 +148,19 @@ def batch(graphs, ndata=ALL, edata=ALL):
unbatch
unbatch
"""
"""
if
len
(
graphs
)
==
0
:
if
len
(
graphs
)
==
0
:
raise
DGLError
(
'
The input list of graphs cannot be empty.
'
)
raise
DGLError
(
"
The input list of graphs cannot be empty.
"
)
if
not
(
is_all
(
ndata
)
or
isinstance
(
ndata
,
list
)
or
ndata
is
None
):
if
not
(
is_all
(
ndata
)
or
isinstance
(
ndata
,
list
)
or
ndata
is
None
):
raise
DGLError
(
'Invalid argument ndata: must be a string list but got {}.'
.
format
(
raise
DGLError
(
type
(
ndata
)))
"Invalid argument ndata: must be a string list but got {}."
.
format
(
type
(
ndata
)
)
)
if
not
(
is_all
(
edata
)
or
isinstance
(
edata
,
list
)
or
edata
is
None
):
if
not
(
is_all
(
edata
)
or
isinstance
(
edata
,
list
)
or
edata
is
None
):
raise
DGLError
(
'Invalid argument edata: must be a string list but got {}.'
.
format
(
raise
DGLError
(
type
(
edata
)))
"Invalid argument edata: must be a string list but got {}."
.
format
(
type
(
edata
)
)
)
if
any
(
g
.
is_block
for
g
in
graphs
):
if
any
(
g
.
is_block
for
g
in
graphs
):
raise
DGLError
(
"Batching a MFG is not supported."
)
raise
DGLError
(
"Batching a MFG is not supported."
)
...
@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
...
@@ -165,7 +170,9 @@ def batch(graphs, ndata=ALL, edata=ALL):
ntype_ids
=
[
graphs
[
0
].
get_ntype_id
(
n
)
for
n
in
ntypes
]
ntype_ids
=
[
graphs
[
0
].
get_ntype_id
(
n
)
for
n
in
ntypes
]
etypes
=
[
etype
for
_
,
etype
,
_
in
relations
]
etypes
=
[
etype
for
_
,
etype
,
_
in
relations
]
gidx
=
disjoint_union
(
graphs
[
0
].
_graph
.
metagraph
,
[
g
.
_graph
for
g
in
graphs
])
gidx
=
disjoint_union
(
graphs
[
0
].
_graph
.
metagraph
,
[
g
.
_graph
for
g
in
graphs
]
)
retg
=
DGLGraph
(
gidx
,
ntypes
,
etypes
)
retg
=
DGLGraph
(
gidx
,
ntypes
,
etypes
)
# Compute batch num nodes
# Compute batch num nodes
...
@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
...
@@ -183,29 +190,42 @@ def batch(graphs, ndata=ALL, edata=ALL):
# Batch node feature
# Batch node feature
if
ndata
is
not
None
:
if
ndata
is
not
None
:
for
ntype_id
,
ntype
in
zip
(
ntype_ids
,
ntypes
):
for
ntype_id
,
ntype
in
zip
(
ntype_ids
,
ntypes
):
all_empty
=
all
(
g
.
_graph
.
number_of_nodes
(
ntype_id
)
==
0
for
g
in
graphs
)
all_empty
=
all
(
g
.
_graph
.
number_of_nodes
(
ntype_id
)
==
0
for
g
in
graphs
)
frames
=
[
frames
=
[
g
.
_node_frames
[
ntype_id
]
for
g
in
graphs
g
.
_node_frames
[
ntype_id
]
if
g
.
_graph
.
number_of_nodes
(
ntype_id
)
>
0
or
all_empty
]
for
g
in
graphs
if
g
.
_graph
.
number_of_nodes
(
ntype_id
)
>
0
or
all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
# we allow empty graphs to have no features during batching.
ret_feat
=
_batch_feat_dicts
(
frames
,
ndata
,
'nodes["{}"].data'
.
format
(
ntype
))
ret_feat
=
_batch_feat_dicts
(
frames
,
ndata
,
'nodes["{}"].data'
.
format
(
ntype
)
)
retg
.
nodes
[
ntype
].
data
.
update
(
ret_feat
)
retg
.
nodes
[
ntype
].
data
.
update
(
ret_feat
)
# Batch edge feature
# Batch edge feature
if
edata
is
not
None
:
if
edata
is
not
None
:
for
etype_id
,
etype
in
zip
(
relation_ids
,
relations
):
for
etype_id
,
etype
in
zip
(
relation_ids
,
relations
):
all_empty
=
all
(
g
.
_graph
.
number_of_edges
(
etype_id
)
==
0
for
g
in
graphs
)
all_empty
=
all
(
g
.
_graph
.
number_of_edges
(
etype_id
)
==
0
for
g
in
graphs
)
frames
=
[
frames
=
[
g
.
_edge_frames
[
etype_id
]
for
g
in
graphs
g
.
_edge_frames
[
etype_id
]
if
g
.
_graph
.
number_of_edges
(
etype_id
)
>
0
or
all_empty
]
for
g
in
graphs
if
g
.
_graph
.
number_of_edges
(
etype_id
)
>
0
or
all_empty
]
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# TODO: do we require graphs with no nodes/edges to have the same schema? Currently
# we allow empty graphs to have no features during batching.
# we allow empty graphs to have no features during batching.
ret_feat
=
_batch_feat_dicts
(
frames
,
edata
,
'edges[{}].data'
.
format
(
etype
))
ret_feat
=
_batch_feat_dicts
(
frames
,
edata
,
"edges[{}].data"
.
format
(
etype
)
)
retg
.
edges
[
etype
].
data
.
update
(
ret_feat
)
retg
.
edges
[
etype
].
data
.
update
(
ret_feat
)
return
retg
return
retg
def
_batch_feat_dicts
(
frames
,
keys
,
feat_dict_name
):
def
_batch_feat_dicts
(
frames
,
keys
,
feat_dict_name
):
"""Internal function to batch feature dictionaries.
"""Internal function to batch feature dictionaries.
...
@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
...
@@ -233,9 +253,10 @@ def _batch_feat_dicts(frames, keys, feat_dict_name):
else
:
else
:
utils
.
check_all_same_schema_for_keys
(
schemas
,
keys
,
feat_dict_name
)
utils
.
check_all_same_schema_for_keys
(
schemas
,
keys
,
feat_dict_name
)
# concat features
# concat features
ret_feat
=
{
k
:
F
.
cat
([
fd
[
k
]
for
fd
in
frames
],
0
)
for
k
in
keys
}
ret_feat
=
{
k
:
F
.
cat
([
fd
[
k
]
for
fd
in
frames
],
0
)
for
k
in
keys
}
return
ret_feat
return
ret_feat
def
unbatch
(
g
,
node_split
=
None
,
edge_split
=
None
):
def
unbatch
(
g
,
node_split
=
None
,
edge_split
=
None
):
"""Revert the batch operation by split the given graph into a list of small ones.
"""Revert the batch operation by split the given graph into a list of small ones.
...
@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
...
@@ -339,57 +360,75 @@ def unbatch(g, node_split=None, edge_split=None):
num_split
=
None
num_split
=
None
# Parse node_split
# Parse node_split
if
node_split
is
None
:
if
node_split
is
None
:
node_split
=
{
ntype
:
g
.
batch_num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
}
node_split
=
{
ntype
:
g
.
batch_num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
}
elif
not
isinstance
(
node_split
,
Mapping
):
elif
not
isinstance
(
node_split
,
Mapping
):
if
len
(
g
.
ntypes
)
!=
1
:
if
len
(
g
.
ntypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary for argument node_split when'
raise
DGLError
(
' there are multiple node types.'
)
"Must provide a dictionary for argument node_split when"
node_split
=
{
g
.
ntypes
[
0
]
:
node_split
}
" there are multiple node types."
)
node_split
=
{
g
.
ntypes
[
0
]:
node_split
}
if
node_split
.
keys
()
!=
set
(
g
.
ntypes
):
if
node_split
.
keys
()
!=
set
(
g
.
ntypes
):
raise
DGLError
(
'
Must specify node_split for each node type.
'
)
raise
DGLError
(
"
Must specify node_split for each node type.
"
)
for
split
in
node_split
.
values
():
for
split
in
node_split
.
values
():
if
num_split
is
not
None
and
num_split
!=
len
(
split
):
if
num_split
is
not
None
and
num_split
!=
len
(
split
):
raise
DGLError
(
'All node_split and edge_split must specify the same number'
raise
DGLError
(
' of split sizes.'
)
"All node_split and edge_split must specify the same number"
" of split sizes."
)
num_split
=
len
(
split
)
num_split
=
len
(
split
)
# Parse edge_split
# Parse edge_split
if
edge_split
is
None
:
if
edge_split
is
None
:
edge_split
=
{
etype
:
g
.
batch_num_edges
(
etype
)
for
etype
in
g
.
canonical_etypes
}
edge_split
=
{
etype
:
g
.
batch_num_edges
(
etype
)
for
etype
in
g
.
canonical_etypes
}
elif
not
isinstance
(
edge_split
,
Mapping
):
elif
not
isinstance
(
edge_split
,
Mapping
):
if
len
(
g
.
etypes
)
!=
1
:
if
len
(
g
.
etypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary for argument edge_split when'
raise
DGLError
(
' there are multiple edge types.'
)
"Must provide a dictionary for argument edge_split when"
edge_split
=
{
g
.
canonical_etypes
[
0
]
:
edge_split
}
" there are multiple edge types."
)
edge_split
=
{
g
.
canonical_etypes
[
0
]:
edge_split
}
if
edge_split
.
keys
()
!=
set
(
g
.
canonical_etypes
):
if
edge_split
.
keys
()
!=
set
(
g
.
canonical_etypes
):
raise
DGLError
(
'
Must specify edge_split for each canonical edge type.
'
)
raise
DGLError
(
"
Must specify edge_split for each canonical edge type.
"
)
for
split
in
edge_split
.
values
():
for
split
in
edge_split
.
values
():
if
num_split
is
not
None
and
num_split
!=
len
(
split
):
if
num_split
is
not
None
and
num_split
!=
len
(
split
):
raise
DGLError
(
'All edge_split and edge_split must specify the same number'
raise
DGLError
(
' of split sizes.'
)
"All edge_split and edge_split must specify the same number"
" of split sizes."
)
num_split
=
len
(
split
)
num_split
=
len
(
split
)
node_split
=
{
k
:
F
.
asnumpy
(
split
).
tolist
()
for
k
,
split
in
node_split
.
items
()}
node_split
=
{
edge_split
=
{
k
:
F
.
asnumpy
(
split
).
tolist
()
for
k
,
split
in
edge_split
.
items
()}
k
:
F
.
asnumpy
(
split
).
tolist
()
for
k
,
split
in
node_split
.
items
()
}
edge_split
=
{
k
:
F
.
asnumpy
(
split
).
tolist
()
for
k
,
split
in
edge_split
.
items
()
}
# Split edges for each relation
# Split edges for each relation
edge_dict_per
=
[{}
for
i
in
range
(
num_split
)]
edge_dict_per
=
[{}
for
i
in
range
(
num_split
)]
for
rel
in
g
.
canonical_etypes
:
for
rel
in
g
.
canonical_etypes
:
srctype
,
etype
,
dsttype
=
rel
srctype
,
etype
,
dsttype
=
rel
srcnid_off
=
dstnid_off
=
0
srcnid_off
=
dstnid_off
=
0
u
,
v
=
g
.
edges
(
order
=
'
eid
'
,
etype
=
rel
)
u
,
v
=
g
.
edges
(
order
=
"
eid
"
,
etype
=
rel
)
us
=
F
.
split
(
u
,
edge_split
[
rel
],
0
)
us
=
F
.
split
(
u
,
edge_split
[
rel
],
0
)
vs
=
F
.
split
(
v
,
edge_split
[
rel
],
0
)
vs
=
F
.
split
(
v
,
edge_split
[
rel
],
0
)
for
i
,
(
subu
,
subv
)
in
enumerate
(
zip
(
us
,
vs
)):
for
i
,
(
subu
,
subv
)
in
enumerate
(
zip
(
us
,
vs
)):
edge_dict_per
[
i
][
rel
]
=
(
subu
-
srcnid_off
,
subv
-
dstnid_off
)
edge_dict_per
[
i
][
rel
]
=
(
subu
-
srcnid_off
,
subv
-
dstnid_off
)
srcnid_off
+=
node_split
[
srctype
][
i
]
srcnid_off
+=
node_split
[
srctype
][
i
]
dstnid_off
+=
node_split
[
dsttype
][
i
]
dstnid_off
+=
node_split
[
dsttype
][
i
]
num_nodes_dict_per
=
[{
k
:
split
[
i
]
for
k
,
split
in
node_split
.
items
()}
num_nodes_dict_per
=
[
for
i
in
range
(
num_split
)]
{
k
:
split
[
i
]
for
k
,
split
in
node_split
.
items
()}
for
i
in
range
(
num_split
)
]
# Create graphs
# Create graphs
gs
=
[
convert
.
heterograph
(
edge_dict
,
num_nodes_dict
,
idtype
=
g
.
idtype
)
gs
=
[
for
edge_dict
,
num_nodes_dict
in
zip
(
edge_dict_per
,
num_nodes_dict_per
)]
convert
.
heterograph
(
edge_dict
,
num_nodes_dict
,
idtype
=
g
.
idtype
)
for
edge_dict
,
num_nodes_dict
in
zip
(
edge_dict_per
,
num_nodes_dict_per
)
]
# Unbatch node features
# Unbatch node features
for
ntype
in
g
.
ntypes
:
for
ntype
in
g
.
ntypes
:
...
@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
...
@@ -407,6 +446,7 @@ def unbatch(g, node_split=None, edge_split=None):
return
gs
return
gs
def
slice_batch
(
g
,
gid
,
store_ids
=
False
):
def
slice_batch
(
g
,
gid
,
store_ids
=
False
):
"""Get a particular graph from a batch of graphs.
"""Get a particular graph from a batch of graphs.
...
@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
...
@@ -455,7 +495,9 @@ def slice_batch(g, gid, store_ids=False):
if
gid
==
0
:
if
gid
==
0
:
start_nid
.
append
(
0
)
start_nid
.
append
(
0
)
else
:
else
:
start_nid
.
append
(
F
.
as_scalar
(
F
.
sum
(
F
.
slice_axis
(
batch_num_nodes
,
0
,
0
,
gid
),
0
)))
start_nid
.
append
(
F
.
as_scalar
(
F
.
sum
(
F
.
slice_axis
(
batch_num_nodes
,
0
,
0
,
gid
),
0
))
)
start_eid
=
[]
start_eid
=
[]
num_edges
=
[]
num_edges
=
[]
...
@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
...
@@ -465,33 +507,42 @@ def slice_batch(g, gid, store_ids=False):
if
gid
==
0
:
if
gid
==
0
:
start_eid
.
append
(
0
)
start_eid
.
append
(
0
)
else
:
else
:
start_eid
.
append
(
F
.
as_scalar
(
F
.
sum
(
F
.
slice_axis
(
batch_num_edges
,
0
,
0
,
gid
),
0
)))
start_eid
.
append
(
F
.
as_scalar
(
F
.
sum
(
F
.
slice_axis
(
batch_num_edges
,
0
,
0
,
gid
),
0
))
)
# Slice graph structure
# Slice graph structure
gidx
=
slice_gidx
(
g
.
_graph
,
utils
.
toindex
(
num_nodes
),
utils
.
toindex
(
start_nid
),
gidx
=
slice_gidx
(
utils
.
toindex
(
num_edges
),
utils
.
toindex
(
start_eid
))
g
.
_graph
,
utils
.
toindex
(
num_nodes
),
utils
.
toindex
(
start_nid
),
utils
.
toindex
(
num_edges
),
utils
.
toindex
(
start_eid
),
)
retg
=
DGLGraph
(
gidx
,
g
.
ntypes
,
g
.
etypes
)
retg
=
DGLGraph
(
gidx
,
g
.
ntypes
,
g
.
etypes
)
# Slice node features
# Slice node features
for
ntid
,
ntype
in
enumerate
(
g
.
ntypes
):
for
ntid
,
ntype
in
enumerate
(
g
.
ntypes
):
stnid
=
start_nid
[
ntid
]
stnid
=
start_nid
[
ntid
]
for
key
,
feat
in
g
.
nodes
[
ntype
].
data
.
items
():
for
key
,
feat
in
g
.
nodes
[
ntype
].
data
.
items
():
subfeats
=
F
.
slice_axis
(
feat
,
0
,
stnid
,
stnid
+
num_nodes
[
ntid
])
subfeats
=
F
.
slice_axis
(
feat
,
0
,
stnid
,
stnid
+
num_nodes
[
ntid
])
retg
.
nodes
[
ntype
].
data
[
key
]
=
subfeats
retg
.
nodes
[
ntype
].
data
[
key
]
=
subfeats
if
store_ids
:
if
store_ids
:
retg
.
nodes
[
ntype
].
data
[
NID
]
=
F
.
arange
(
stnid
,
stnid
+
num_nodes
[
ntid
],
retg
.
nodes
[
ntype
].
data
[
NID
]
=
F
.
arange
(
retg
.
idtype
,
retg
.
device
)
stnid
,
stnid
+
num_nodes
[
ntid
],
retg
.
idtype
,
retg
.
device
)
# Slice edge features
# Slice edge features
for
etid
,
etype
in
enumerate
(
g
.
canonical_etypes
):
for
etid
,
etype
in
enumerate
(
g
.
canonical_etypes
):
steid
=
start_eid
[
etid
]
steid
=
start_eid
[
etid
]
for
key
,
feat
in
g
.
edges
[
etype
].
data
.
items
():
for
key
,
feat
in
g
.
edges
[
etype
].
data
.
items
():
subfeats
=
F
.
slice_axis
(
feat
,
0
,
steid
,
steid
+
num_edges
[
etid
])
subfeats
=
F
.
slice_axis
(
feat
,
0
,
steid
,
steid
+
num_edges
[
etid
])
retg
.
edges
[
etype
].
data
[
key
]
=
subfeats
retg
.
edges
[
etype
].
data
[
key
]
=
subfeats
if
store_ids
:
if
store_ids
:
retg
.
edges
[
etype
].
data
[
EID
]
=
F
.
arange
(
steid
,
steid
+
num_edges
[
etid
],
retg
.
edges
[
etype
].
data
[
EID
]
=
F
.
arange
(
retg
.
idtype
,
retg
.
device
)
steid
,
steid
+
num_edges
[
etid
],
retg
.
idtype
,
retg
.
device
)
return
retg
return
retg
python/dgl/convert.py
View file @
a566b60b
"""Module for converting graph from/to other object."""
"""Module for converting graph from/to other object."""
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
scipy.sparse
import
spmatrix
import
numpy
as
np
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
from
scipy.sparse
import
spmatrix
from
.
import
backend
as
F
from
.
import
backend
as
F
,
graph_index
,
heterograph_index
,
utils
from
.
import
heterograph_index
from
.base
import
DGLError
,
EID
,
ETYPE
,
NID
,
NTYPE
from
.heterograph
import
DGLGraph
,
combine_frames
,
DGLBlock
from
.heterograph
import
combine_frames
,
DGLBlock
,
DGLGraph
from
.
import
graph_index
from
.
import
utils
from
.base
import
NTYPE
,
ETYPE
,
NID
,
EID
,
DGLError
__all__
=
[
__all__
=
[
'
graph
'
,
"
graph
"
,
'
hetero_from_shared_memory
'
,
"
hetero_from_shared_memory
"
,
'
heterograph
'
,
"
heterograph
"
,
'
create_block
'
,
"
create_block
"
,
'
block_to_graph
'
,
"
block_to_graph
"
,
'
to_heterogeneous
'
,
"
to_heterogeneous
"
,
'
to_homogeneous
'
,
"
to_homogeneous
"
,
'
from_scipy
'
,
"
from_scipy
"
,
'
bipartite_from_scipy
'
,
"
bipartite_from_scipy
"
,
'
from_networkx
'
,
"
from_networkx
"
,
'
bipartite_from_networkx
'
,
"
bipartite_from_networkx
"
,
'
to_networkx
'
,
"
to_networkx
"
,
'
from_cugraph
'
,
"
from_cugraph
"
,
'
to_cugraph
'
"
to_cugraph
"
,
]
]
def
graph
(
data
,
*
,
def
graph
(
num_nodes
=
None
,
data
,
idtype
=
None
,
*
,
device
=
None
,
num_nodes
=
None
,
row_sorted
=
False
,
idtype
=
None
,
col_sorted
=
False
):
device
=
None
,
row_sorted
=
False
,
col_sorted
=
False
,
):
"""Create a graph and return.
"""Create a graph and return.
Parameters
Parameters
...
@@ -147,25 +148,41 @@ def graph(data,
...
@@ -147,25 +148,41 @@ def graph(data,
from_networkx
from_networkx
"""
"""
if
isinstance
(
data
,
spmatrix
):
if
isinstance
(
data
,
spmatrix
):
raise
DGLError
(
"dgl.graph no longer supports graph construction from a SciPy "
raise
DGLError
(
"sparse matrix, use dgl.from_scipy instead."
)
"dgl.graph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead."
)
if
isinstance
(
data
,
nx
.
Graph
):
if
isinstance
(
data
,
nx
.
Graph
):
raise
DGLError
(
"dgl.graph no longer supports graph construction from a NetworkX "
raise
DGLError
(
"graph, use dgl.from_networkx instead."
)
"dgl.graph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead."
)
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
)
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
)
if
num_nodes
is
not
None
:
# override the number of nodes
if
num_nodes
is
not
None
:
# override the number of nodes
if
num_nodes
<
max
(
urange
,
vrange
):
if
num_nodes
<
max
(
urange
,
vrange
):
raise
DGLError
(
'The num_nodes argument must be larger than the max ID in the data,'
raise
DGLError
(
' but got {} and {}.'
.
format
(
num_nodes
,
max
(
urange
,
vrange
)
-
1
))
"The num_nodes argument must be larger than the max ID in the data,"
" but got {} and {}."
.
format
(
num_nodes
,
max
(
urange
,
vrange
)
-
1
)
)
urange
,
vrange
=
num_nodes
,
num_nodes
urange
,
vrange
=
num_nodes
,
num_nodes
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
'_N'
,
'_E'
,
'_N'
,
urange
,
vrange
,
g
=
create_from_edges
(
row_sorted
=
row_sorted
,
col_sorted
=
col_sorted
)
sparse_fmt
,
arrays
,
"_N"
,
"_E"
,
"_N"
,
urange
,
vrange
,
row_sorted
=
row_sorted
,
col_sorted
=
col_sorted
,
)
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
hetero_from_shared_memory
(
name
):
def
hetero_from_shared_memory
(
name
):
"""Create a heterograph from shared memory with the given name.
"""Create a heterograph from shared memory with the given name.
...
@@ -181,13 +198,13 @@ def hetero_from_shared_memory(name):
...
@@ -181,13 +198,13 @@ def hetero_from_shared_memory(name):
-------
-------
HeteroGraph (in shared memory)
HeteroGraph (in shared memory)
"""
"""
g
,
ntypes
,
etypes
=
heterograph_index
.
create_heterograph_from_shared_memory
(
name
)
g
,
ntypes
,
etypes
=
heterograph_index
.
create_heterograph_from_shared_memory
(
name
)
return
DGLGraph
(
g
,
ntypes
,
etypes
)
return
DGLGraph
(
g
,
ntypes
,
etypes
)
def
heterograph
(
data_dict
,
num_nodes_dict
=
None
,
def
heterograph
(
data_dict
,
num_nodes_dict
=
None
,
idtype
=
None
,
device
=
None
):
idtype
=
None
,
device
=
None
):
"""Create a heterogeneous graph and return.
"""Create a heterogeneous graph and return.
Parameters
Parameters
...
@@ -300,47 +317,77 @@ def heterograph(data_dict,
...
@@ -300,47 +317,77 @@ def heterograph(data_dict,
num_nodes_dict
=
defaultdict
(
int
)
num_nodes_dict
=
defaultdict
(
int
)
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
if
isinstance
(
data
,
spmatrix
):
if
isinstance
(
data
,
spmatrix
):
raise
DGLError
(
"dgl.heterograph no longer supports graph construction from a SciPy "
raise
DGLError
(
"sparse matrix, use dgl.from_scipy instead."
)
"dgl.heterograph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead."
)
if
isinstance
(
data
,
nx
.
Graph
):
if
isinstance
(
data
,
nx
.
Graph
):
raise
DGLError
(
"dgl.heterograph no longer supports graph construction from a NetworkX "
raise
DGLError
(
"graph, use dgl.from_networkx instead."
)
"dgl.heterograph no longer supports graph construction from a NetworkX "
is_bipartite
=
(
sty
!=
dty
)
"graph, use dgl.from_networkx instead."
)
is_bipartite
=
sty
!=
dty
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
,
bipartite
=
is_bipartite
)
data
,
idtype
,
bipartite
=
is_bipartite
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
if
need_infer
:
if
need_infer
:
num_nodes_dict
[
sty
]
=
max
(
num_nodes_dict
[
sty
],
urange
)
num_nodes_dict
[
sty
]
=
max
(
num_nodes_dict
[
sty
],
urange
)
num_nodes_dict
[
dty
]
=
max
(
num_nodes_dict
[
dty
],
vrange
)
num_nodes_dict
[
dty
]
=
max
(
num_nodes_dict
[
dty
],
vrange
)
else
:
# sanity check
else
:
# sanity check
if
num_nodes_dict
[
sty
]
<
urange
:
if
num_nodes_dict
[
sty
]
<
urange
:
raise
DGLError
(
'The given number of nodes of node type {} must be larger than'
raise
DGLError
(
' the max ID in the data, but got {} and {}.'
.
format
(
"The given number of nodes of node type {} must be larger than"
sty
,
num_nodes_dict
[
sty
],
urange
-
1
))
" the max ID in the data, but got {} and {}."
.
format
(
sty
,
num_nodes_dict
[
sty
],
urange
-
1
)
)
if
num_nodes_dict
[
dty
]
<
vrange
:
if
num_nodes_dict
[
dty
]
<
vrange
:
raise
DGLError
(
'The given number of nodes of node type {} must be larger than'
raise
DGLError
(
' the max ID in the data, but got {} and {}.'
.
format
(
"The given number of nodes of node type {} must be larger than"
dty
,
num_nodes_dict
[
dty
],
vrange
-
1
))
" the max ID in the data, but got {} and {}."
.
format
(
dty
,
num_nodes_dict
[
dty
],
vrange
-
1
)
)
# Create the graph
# Create the graph
metagraph
,
ntypes
,
etypes
,
relations
=
heterograph_index
.
create_metagraph_index
(
(
num_nodes_dict
.
keys
(),
node_tensor_dict
.
keys
())
metagraph
,
num_nodes_per_type
=
utils
.
toindex
([
num_nodes_dict
[
ntype
]
for
ntype
in
ntypes
],
"int64"
)
ntypes
,
etypes
,
relations
,
)
=
heterograph_index
.
create_metagraph_index
(
num_nodes_dict
.
keys
(),
node_tensor_dict
.
keys
()
)
num_nodes_per_type
=
utils
.
toindex
(
[
num_nodes_dict
[
ntype
]
for
ntype
in
ntypes
],
"int64"
)
rel_graphs
=
[]
rel_graphs
=
[]
for
srctype
,
etype
,
dsttype
in
relations
:
for
srctype
,
etype
,
dsttype
in
relations
:
sparse_fmt
,
arrays
=
node_tensor_dict
[(
srctype
,
etype
,
dsttype
)]
sparse_fmt
,
arrays
=
node_tensor_dict
[(
srctype
,
etype
,
dsttype
)]
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
srctype
,
etype
,
dsttype
,
g
=
create_from_edges
(
num_nodes_dict
[
srctype
],
num_nodes_dict
[
dsttype
])
sparse_fmt
,
arrays
,
srctype
,
etype
,
dsttype
,
num_nodes_dict
[
srctype
],
num_nodes_dict
[
dsttype
],
)
rel_graphs
.
append
(
g
)
rel_graphs
.
append
(
g
)
# create graph index
# create graph index
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
metagraph
,
[
rgrh
.
_graph
for
rgrh
in
rel_graphs
],
num_nodes_per_type
)
metagraph
,
[
rgrh
.
_graph
for
rgrh
in
rel_graphs
],
num_nodes_per_type
)
retg
=
DGLGraph
(
hgidx
,
ntypes
,
etypes
)
retg
=
DGLGraph
(
hgidx
,
ntypes
,
etypes
)
return
retg
.
to
(
device
)
return
retg
.
to
(
device
)
def
create_block
(
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
):
def
create_block
(
data_dict
,
num_src_nodes
=
None
,
num_dst_nodes
=
None
,
idtype
=
None
,
device
=
None
):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
Parameters
Parameters
...
@@ -464,21 +511,25 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
...
@@ -464,21 +511,25 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
"""
"""
need_infer
=
num_src_nodes
is
None
and
num_dst_nodes
is
None
need_infer
=
num_src_nodes
is
None
and
num_dst_nodes
is
None
if
not
isinstance
(
data_dict
,
Mapping
):
if
not
isinstance
(
data_dict
,
Mapping
):
data_dict
=
{(
'
_N
'
,
'
_E
'
,
'
_N
'
):
data_dict
}
data_dict
=
{(
"
_N
"
,
"
_E
"
,
"
_N
"
):
data_dict
}
if
not
need_infer
:
if
not
need_infer
:
assert
isinstance
(
num_src_nodes
,
int
),
\
assert
isinstance
(
"num_src_nodes must be a pair of integers if data_dict is not a dict"
num_src_nodes
,
int
assert
isinstance
(
num_dst_nodes
,
int
),
\
),
"num_src_nodes must be a pair of integers if data_dict is not a dict"
"num_dst_nodes must be a pair of integers if data_dict is not a dict"
assert
isinstance
(
num_src_nodes
=
{
'_N'
:
num_src_nodes
}
num_dst_nodes
,
int
num_dst_nodes
=
{
'_N'
:
num_dst_nodes
}
),
"num_dst_nodes must be a pair of integers if data_dict is not a dict"
num_src_nodes
=
{
"_N"
:
num_src_nodes
}
num_dst_nodes
=
{
"_N"
:
num_dst_nodes
}
else
:
else
:
if
not
need_infer
:
if
not
need_infer
:
assert
isinstance
(
num_src_nodes
,
Mapping
),
\
assert
isinstance
(
"num_src_nodes must be a dict if data_dict is a dict"
num_src_nodes
,
Mapping
assert
isinstance
(
num_dst_nodes
,
Mapping
),
\
),
"num_src_nodes must be a dict if data_dict is a dict"
"num_dst_nodes must be a dict if data_dict is a dict"
assert
isinstance
(
num_dst_nodes
,
Mapping
),
"num_dst_nodes must be a dict if data_dict is a dict"
if
need_infer
:
if
need_infer
:
num_src_nodes
=
defaultdict
(
int
)
num_src_nodes
=
defaultdict
(
int
)
...
@@ -488,20 +539,27 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
...
@@ -488,20 +539,27 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
node_tensor_dict
=
{}
node_tensor_dict
=
{}
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
for
(
sty
,
ety
,
dty
),
data
in
data_dict
.
items
():
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
data
,
idtype
,
bipartite
=
True
)
data
,
idtype
,
bipartite
=
True
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
node_tensor_dict
[(
sty
,
ety
,
dty
)]
=
(
sparse_fmt
,
arrays
)
if
need_infer
:
if
need_infer
:
num_src_nodes
[
sty
]
=
max
(
num_src_nodes
[
sty
],
urange
)
num_src_nodes
[
sty
]
=
max
(
num_src_nodes
[
sty
],
urange
)
num_dst_nodes
[
dty
]
=
max
(
num_dst_nodes
[
dty
],
vrange
)
num_dst_nodes
[
dty
]
=
max
(
num_dst_nodes
[
dty
],
vrange
)
else
:
# sanity check
else
:
# sanity check
if
num_src_nodes
[
sty
]
<
urange
:
if
num_src_nodes
[
sty
]
<
urange
:
raise
DGLError
(
'The given number of nodes of source node type {} must be larger'
raise
DGLError
(
' than the max ID in the data, but got {} and {}.'
.
format
(
"The given number of nodes of source node type {} must be larger"
sty
,
num_src_nodes
[
sty
],
urange
-
1
))
" than the max ID in the data, but got {} and {}."
.
format
(
sty
,
num_src_nodes
[
sty
],
urange
-
1
)
)
if
num_dst_nodes
[
dty
]
<
vrange
:
if
num_dst_nodes
[
dty
]
<
vrange
:
raise
DGLError
(
'The given number of nodes of destination node type {} must be'
raise
DGLError
(
' larger than the max ID in the data, but got {} and {}.'
.
format
(
"The given number of nodes of destination node type {} must be"
dty
,
num_dst_nodes
[
dty
],
vrange
-
1
))
" larger than the max ID in the data, but got {} and {}."
.
format
(
dty
,
num_dst_nodes
[
dty
],
vrange
-
1
)
)
# Create the graph
# Create the graph
# Sort the ntypes and relation tuples to have a deterministic order for the same set
# Sort the ntypes and relation tuples to have a deterministic order for the same set
...
@@ -511,10 +569,14 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
...
@@ -511,10 +569,14 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
relations
=
list
(
sorted
(
node_tensor_dict
.
keys
()))
relations
=
list
(
sorted
(
node_tensor_dict
.
keys
()))
num_nodes_per_type
=
utils
.
toindex
(
num_nodes_per_type
=
utils
.
toindex
(
[
num_src_nodes
[
ntype
]
for
ntype
in
srctypes
]
+
[
num_src_nodes
[
ntype
]
for
ntype
in
srctypes
]
[
num_dst_nodes
[
ntype
]
for
ntype
in
dsttypes
],
"int64"
)
+
[
num_dst_nodes
[
ntype
]
for
ntype
in
dsttypes
],
"int64"
,
)
srctype_dict
=
{
ntype
:
i
for
i
,
ntype
in
enumerate
(
srctypes
)}
srctype_dict
=
{
ntype
:
i
for
i
,
ntype
in
enumerate
(
srctypes
)}
dsttype_dict
=
{
ntype
:
i
+
len
(
srctypes
)
for
i
,
ntype
in
enumerate
(
dsttypes
)}
dsttype_dict
=
{
ntype
:
i
+
len
(
srctypes
)
for
i
,
ntype
in
enumerate
(
dsttypes
)
}
meta_edges_src
=
[]
meta_edges_src
=
[]
meta_edges_dst
=
[]
meta_edges_dst
=
[]
...
@@ -525,20 +587,30 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
...
@@ -525,20 +587,30 @@ def create_block(data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None,
meta_edges_dst
.
append
(
dsttype_dict
[
dsttype
])
meta_edges_dst
.
append
(
dsttype_dict
[
dsttype
])
etypes
.
append
(
etype
)
etypes
.
append
(
etype
)
sparse_fmt
,
arrays
=
node_tensor_dict
[(
srctype
,
etype
,
dsttype
)]
sparse_fmt
,
arrays
=
node_tensor_dict
[(
srctype
,
etype
,
dsttype
)]
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
'SRC/'
+
srctype
,
etype
,
'DST/'
+
dsttype
,
g
=
create_from_edges
(
num_src_nodes
[
srctype
],
num_dst_nodes
[
dsttype
])
sparse_fmt
,
arrays
,
"SRC/"
+
srctype
,
etype
,
"DST/"
+
dsttype
,
num_src_nodes
[
srctype
],
num_dst_nodes
[
dsttype
],
)
rel_graphs
.
append
(
g
)
rel_graphs
.
append
(
g
)
# metagraph is DGLGraph, currently still using int64 as index dtype
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph
=
graph_index
.
from_coo
(
metagraph
=
graph_index
.
from_coo
(
len
(
srctypes
)
+
len
(
dsttypes
),
meta_edges_src
,
meta_edges_dst
,
True
)
len
(
srctypes
)
+
len
(
dsttypes
),
meta_edges_src
,
meta_edges_dst
,
True
)
# create graph index
# create graph index
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
metagraph
,
[
rgrh
.
_graph
for
rgrh
in
rel_graphs
],
num_nodes_per_type
)
metagraph
,
[
rgrh
.
_graph
for
rgrh
in
rel_graphs
],
num_nodes_per_type
)
retg
=
DGLBlock
(
hgidx
,
(
srctypes
,
dsttypes
),
etypes
)
retg
=
DGLBlock
(
hgidx
,
(
srctypes
,
dsttypes
),
etypes
)
return
retg
.
to
(
device
)
return
retg
.
to
(
device
)
def
block_to_graph
(
block
):
def
block_to_graph
(
block
):
"""Convert a message flow graph (MFG) as a :class:`DGLBlock` object to a :class:`DGLGraph`.
"""Convert a message flow graph (MFG) as a :class:`DGLBlock` object to a :class:`DGLGraph`.
...
@@ -568,22 +640,26 @@ def block_to_graph(block):
...
@@ -568,22 +640,26 @@ def block_to_graph(block):
num_edges={('A_src', 'AB', 'B_dst'): 3, ('B_src', 'BA', 'A_dst'): 2},
num_edges={('A_src', 'AB', 'B_dst'): 3, ('B_src', 'BA', 'A_dst'): 2},
metagraph=[('A_src', 'B_dst', 'AB'), ('B_src', 'A_dst', 'BA')])
metagraph=[('A_src', 'B_dst', 'AB'), ('B_src', 'A_dst', 'BA')])
"""
"""
new_types
=
[
ntype
+
'_src'
for
ntype
in
block
.
srctypes
]
+
\
new_types
=
[
ntype
+
"_src"
for
ntype
in
block
.
srctypes
]
+
[
[
ntype
+
'_dst'
for
ntype
in
block
.
dsttypes
]
ntype
+
"_dst"
for
ntype
in
block
.
dsttypes
]
retg
=
DGLGraph
(
block
.
_graph
,
new_types
,
block
.
etypes
)
retg
=
DGLGraph
(
block
.
_graph
,
new_types
,
block
.
etypes
)
for
srctype
in
block
.
srctypes
:
for
srctype
in
block
.
srctypes
:
retg
.
nodes
[
srctype
+
'
_src
'
].
data
.
update
(
block
.
srcnodes
[
srctype
].
data
)
retg
.
nodes
[
srctype
+
"
_src
"
].
data
.
update
(
block
.
srcnodes
[
srctype
].
data
)
for
dsttype
in
block
.
dsttypes
:
for
dsttype
in
block
.
dsttypes
:
retg
.
nodes
[
dsttype
+
'
_dst
'
].
data
.
update
(
block
.
dstnodes
[
dsttype
].
data
)
retg
.
nodes
[
dsttype
+
"
_dst
"
].
data
.
update
(
block
.
dstnodes
[
dsttype
].
data
)
for
srctype
,
etype
,
dsttype
in
block
.
canonical_etypes
:
for
srctype
,
etype
,
dsttype
in
block
.
canonical_etypes
:
retg
.
edges
[
srctype
+
'_src'
,
etype
,
dsttype
+
'_dst'
].
data
.
update
(
retg
.
edges
[
srctype
+
"_src"
,
etype
,
dsttype
+
"_dst"
].
data
.
update
(
block
.
edges
[
srctype
,
etype
,
dsttype
].
data
)
block
.
edges
[
srctype
,
etype
,
dsttype
].
data
)
return
retg
return
retg
def
to_heterogeneous
(
G
,
ntypes
,
etypes
,
ntype_field
=
NTYPE
,
etype_field
=
ETYPE
,
metagraph
=
None
):
def
to_heterogeneous
(
G
,
ntypes
,
etypes
,
ntype_field
=
NTYPE
,
etype_field
=
ETYPE
,
metagraph
=
None
):
"""Convert a homogeneous graph to a heterogeneous graph and return.
"""Convert a homogeneous graph to a heterogeneous graph and return.
The input graph should have only one type of nodes and edges. Each node and edge
The input graph should have only one type of nodes and edges. Each node and edge
...
@@ -691,10 +767,16 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
...
@@ -691,10 +767,16 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
--------
--------
to_homogeneous
to_homogeneous
"""
"""
if
(
hasattr
(
G
,
'ntypes'
)
and
len
(
G
.
ntypes
)
>
1
if
(
or
hasattr
(
G
,
'etypes'
)
and
len
(
G
.
etypes
)
>
1
):
hasattr
(
G
,
"ntypes"
)
raise
DGLError
(
'The input graph should be homogeneous and have only one '
and
len
(
G
.
ntypes
)
>
1
' type of nodes and edges.'
)
or
hasattr
(
G
,
"etypes"
)
and
len
(
G
.
etypes
)
>
1
):
raise
DGLError
(
"The input graph should be homogeneous and have only one "
" type of nodes and edges."
)
num_ntypes
=
len
(
ntypes
)
num_ntypes
=
len
(
ntypes
)
idtype
=
G
.
idtype
idtype
=
G
.
idtype
...
@@ -706,15 +788,15 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
...
@@ -706,15 +788,15 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
# relabel nodes to per-type local IDs
# relabel nodes to per-type local IDs
ntype_count
=
np
.
bincount
(
ntype_ids
,
minlength
=
num_ntypes
)
ntype_count
=
np
.
bincount
(
ntype_ids
,
minlength
=
num_ntypes
)
ntype_offset
=
np
.
insert
(
np
.
cumsum
(
ntype_count
),
0
,
0
)
ntype_offset
=
np
.
insert
(
np
.
cumsum
(
ntype_count
),
0
,
0
)
ntype_ids_sortidx
=
np
.
argsort
(
ntype_ids
,
kind
=
'
stable
'
)
ntype_ids_sortidx
=
np
.
argsort
(
ntype_ids
,
kind
=
"
stable
"
)
ntype_local_ids
=
np
.
zeros_like
(
ntype_ids
)
ntype_local_ids
=
np
.
zeros_like
(
ntype_ids
)
node_groups
=
[]
node_groups
=
[]
for
i
in
range
(
num_ntypes
):
for
i
in
range
(
num_ntypes
):
node_group
=
ntype_ids_sortidx
[
ntype_offset
[
i
]
:
ntype_offset
[
i
+
1
]]
node_group
=
ntype_ids_sortidx
[
ntype_offset
[
i
]
:
ntype_offset
[
i
+
1
]]
node_groups
.
append
(
node_group
)
node_groups
.
append
(
node_group
)
ntype_local_ids
[
node_group
]
=
np
.
arange
(
ntype_count
[
i
])
ntype_local_ids
[
node_group
]
=
np
.
arange
(
ntype_count
[
i
])
src
,
dst
=
G
.
all_edges
(
order
=
'
eid
'
)
src
,
dst
=
G
.
all_edges
(
order
=
"
eid
"
)
src
=
F
.
asnumpy
(
src
)
src
=
F
.
asnumpy
(
src
)
dst
=
F
.
asnumpy
(
dst
)
dst
=
F
.
asnumpy
(
dst
)
src_local
=
ntype_local_ids
[
src
]
src_local
=
ntype_local_ids
[
src
]
...
@@ -729,21 +811,28 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
...
@@ -729,21 +811,28 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
# above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the
# above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the
# canonical edge type j. We can then group the edges of the same type together.
# canonical edge type j. We can then group the edges of the same type together.
if
metagraph
is
None
:
if
metagraph
is
None
:
canonical_etids
,
_
,
etype_remapped
=
\
canonical_etids
,
_
,
etype_remapped
=
utils
.
make_invmap
(
utils
.
make_invmap
(
list
(
tuple
(
_
)
for
_
in
edge_ctids
),
False
)
list
(
tuple
(
_
)
for
_
in
edge_ctids
),
False
etype_mask
=
(
etype_remapped
[
None
,
:]
==
np
.
arange
(
len
(
canonical_etids
))[:,
None
])
)
etype_mask
=
(
etype_remapped
[
None
,
:]
==
np
.
arange
(
len
(
canonical_etids
))[:,
None
]
)
else
:
else
:
ntypes_invmap
=
{
nt
:
i
for
i
,
nt
in
enumerate
(
ntypes
)}
ntypes_invmap
=
{
nt
:
i
for
i
,
nt
in
enumerate
(
ntypes
)}
etypes_invmap
=
{
et
:
i
for
i
,
et
in
enumerate
(
etypes
)}
etypes_invmap
=
{
et
:
i
for
i
,
et
in
enumerate
(
etypes
)}
canonical_etids
=
[]
canonical_etids
=
[]
for
i
,
(
srctype
,
dsttype
,
etype
)
in
enumerate
(
metagraph
.
edges
(
keys
=
True
)):
for
i
,
(
srctype
,
dsttype
,
etype
)
in
enumerate
(
metagraph
.
edges
(
keys
=
True
)
):
srctype_id
=
ntypes_invmap
[
srctype
]
srctype_id
=
ntypes_invmap
[
srctype
]
etype_id
=
etypes_invmap
[
etype
]
etype_id
=
etypes_invmap
[
etype
]
dsttype_id
=
ntypes_invmap
[
dsttype
]
dsttype_id
=
ntypes_invmap
[
dsttype
]
canonical_etids
.
append
((
srctype_id
,
etype_id
,
dsttype_id
))
canonical_etids
.
append
((
srctype_id
,
etype_id
,
dsttype_id
))
canonical_etids
=
np
.
asarray
(
canonical_etids
)
canonical_etids
=
np
.
asarray
(
canonical_etids
)
etype_mask
=
(
edge_ctids
[
None
,
:]
==
canonical_etids
[:,
None
]).
all
(
2
)
etype_mask
=
(
edge_ctids
[
None
,
:]
==
canonical_etids
[:,
None
]).
all
(
2
)
edge_groups
=
[
etype_mask
[
i
].
nonzero
()[
0
]
for
i
in
range
(
len
(
canonical_etids
))]
edge_groups
=
[
etype_mask
[
i
].
nonzero
()[
0
]
for
i
in
range
(
len
(
canonical_etids
))
]
data_dict
=
dict
()
data_dict
=
dict
()
canonical_etypes
=
[]
canonical_etypes
=
[]
...
@@ -751,13 +840,12 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
...
@@ -751,13 +840,12 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
src_of_etype
=
src_local
[
edge_groups
[
i
]]
src_of_etype
=
src_local
[
edge_groups
[
i
]]
dst_of_etype
=
dst_local
[
edge_groups
[
i
]]
dst_of_etype
=
dst_local
[
edge_groups
[
i
]]
canonical_etypes
.
append
((
ntypes
[
stid
],
etypes
[
etid
],
ntypes
[
dtid
]))
canonical_etypes
.
append
((
ntypes
[
stid
],
etypes
[
etid
],
ntypes
[
dtid
]))
data_dict
[
canonical_etypes
[
-
1
]]
=
\
data_dict
[
canonical_etypes
[
-
1
]]
=
(
src_of_etype
,
dst_of_etype
)
(
src_of_etype
,
dst_of_etype
)
hg
=
heterograph
(
hg
=
heterograph
(
data_dict
,
data_dict
,
dict
(
zip
(
ntypes
,
ntype_count
)),
idtype
=
idtype
,
device
=
device
dict
(
zip
(
ntypes
,
ntype_count
)),
)
idtype
=
idtype
,
device
=
device
)
ntype2ngrp
=
{
ntype
:
node_groups
[
ntid
]
for
ntid
,
ntype
in
enumerate
(
ntypes
)}
ntype2ngrp
=
{
ntype
:
node_groups
[
ntid
]
for
ntid
,
ntype
in
enumerate
(
ntypes
)}
# features
# features
for
key
,
data
in
G
.
ndata
.
items
():
for
key
,
data
in
G
.
ndata
.
items
():
...
@@ -772,19 +860,26 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
...
@@ -772,19 +860,26 @@ def to_heterogeneous(G, ntypes, etypes, ntype_field=NTYPE,
continue
continue
for
etid
in
range
(
len
(
hg
.
canonical_etypes
)):
for
etid
in
range
(
len
(
hg
.
canonical_etypes
)):
rows
=
F
.
copy_to
(
F
.
tensor
(
edge_groups
[
etid
]),
F
.
context
(
data
))
rows
=
F
.
copy_to
(
F
.
tensor
(
edge_groups
[
etid
]),
F
.
context
(
data
))
hg
.
_edge_frames
[
hg
.
get_etype_id
(
canonical_etypes
[
etid
])][
key
]
=
\
hg
.
_edge_frames
[
hg
.
get_etype_id
(
canonical_etypes
[
etid
])][
F
.
gather_row
(
data
,
rows
)
key
]
=
F
.
gather_row
(
data
,
rows
)
# Record the original IDs of the nodes/edges
# Record the original IDs of the nodes/edges
for
ntid
,
ntype
in
enumerate
(
hg
.
ntypes
):
for
ntid
,
ntype
in
enumerate
(
hg
.
ntypes
):
hg
.
_node_frames
[
ntid
][
NID
]
=
F
.
copy_to
(
F
.
tensor
(
ntype2ngrp
[
ntype
]),
device
)
hg
.
_node_frames
[
ntid
][
NID
]
=
F
.
copy_to
(
F
.
tensor
(
ntype2ngrp
[
ntype
]),
device
)
for
etid
in
range
(
len
(
hg
.
canonical_etypes
)):
for
etid
in
range
(
len
(
hg
.
canonical_etypes
)):
hg
.
_edge_frames
[
hg
.
get_etype_id
(
canonical_etypes
[
etid
])][
EID
]
=
\
hg
.
_edge_frames
[
hg
.
get_etype_id
(
canonical_etypes
[
etid
])][
F
.
copy_to
(
F
.
tensor
(
edge_groups
[
etid
]),
device
)
EID
]
=
F
.
copy_to
(
F
.
tensor
(
edge_groups
[
etid
]),
device
)
return
hg
return
hg
def
to_homogeneous
(
G
,
ndata
=
None
,
edata
=
None
,
store_type
=
True
,
return_count
=
False
):
def
to_homogeneous
(
G
,
ndata
=
None
,
edata
=
None
,
store_type
=
True
,
return_count
=
False
):
"""Convert a heterogeneous graph to a homogeneous graph and return.
"""Convert a heterogeneous graph to a homogeneous graph and return.
By default, the function stores the node and edge types of the input graph as
By default, the function stores the node and edge types of the input graph as
...
@@ -902,7 +997,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
...
@@ -902,7 +997,7 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
for
etype_id
,
etype
in
enumerate
(
G
.
canonical_etypes
):
for
etype_id
,
etype
in
enumerate
(
G
.
canonical_etypes
):
srctype
,
_
,
dsttype
=
etype
srctype
,
_
,
dsttype
=
etype
src
,
dst
=
G
.
all_edges
(
etype
=
etype
,
order
=
'
eid
'
)
src
,
dst
=
G
.
all_edges
(
etype
=
etype
,
order
=
"
eid
"
)
num_edges
=
len
(
src
)
num_edges
=
len
(
src
)
srcs
.
append
(
src
+
int
(
offset_per_ntype
[
G
.
get_ntype_id
(
srctype
)]))
srcs
.
append
(
src
+
int
(
offset_per_ntype
[
G
.
get_ntype_id
(
srctype
)]))
dsts
.
append
(
dst
+
int
(
offset_per_ntype
[
G
.
get_ntype_id
(
dsttype
)]))
dsts
.
append
(
dst
+
int
(
offset_per_ntype
[
G
.
get_ntype_id
(
dsttype
)]))
...
@@ -913,16 +1008,24 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
...
@@ -913,16 +1008,24 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
etype_count
.
append
(
num_edges
)
etype_count
.
append
(
num_edges
)
eids
.
append
(
F
.
arange
(
0
,
num_edges
,
G
.
idtype
,
G
.
device
))
eids
.
append
(
F
.
arange
(
0
,
num_edges
,
G
.
idtype
,
G
.
device
))
retg
=
graph
((
F
.
cat
(
srcs
,
0
),
F
.
cat
(
dsts
,
0
)),
num_nodes
=
total_num_nodes
,
retg
=
graph
(
idtype
=
G
.
idtype
,
device
=
G
.
device
)
(
F
.
cat
(
srcs
,
0
),
F
.
cat
(
dsts
,
0
)),
num_nodes
=
total_num_nodes
,
idtype
=
G
.
idtype
,
device
=
G
.
device
,
)
# copy features
# copy features
if
ndata
is
None
:
if
ndata
is
None
:
ndata
=
[]
ndata
=
[]
if
edata
is
None
:
if
edata
is
None
:
edata
=
[]
edata
=
[]
comb_nf
=
combine_frames
(
G
.
_node_frames
,
range
(
len
(
G
.
ntypes
)),
col_names
=
ndata
)
comb_nf
=
combine_frames
(
comb_ef
=
combine_frames
(
G
.
_edge_frames
,
range
(
len
(
G
.
etypes
)),
col_names
=
edata
)
G
.
_node_frames
,
range
(
len
(
G
.
ntypes
)),
col_names
=
ndata
)
comb_ef
=
combine_frames
(
G
.
_edge_frames
,
range
(
len
(
G
.
etypes
)),
col_names
=
edata
)
if
comb_nf
is
not
None
:
if
comb_nf
is
not
None
:
retg
.
ndata
.
update
(
comb_nf
)
retg
.
ndata
.
update
(
comb_nf
)
if
comb_ef
is
not
None
:
if
comb_ef
is
not
None
:
...
@@ -939,10 +1042,8 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
...
@@ -939,10 +1042,8 @@ def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=Fals
else
:
else
:
return
retg
return
retg
def
from_scipy
(
sp_mat
,
eweight_name
=
None
,
def
from_scipy
(
sp_mat
,
eweight_name
=
None
,
idtype
=
None
,
device
=
None
):
idtype
=
None
,
device
=
None
):
"""Create a graph from a SciPy sparse matrix and return.
"""Create a graph from a SciPy sparse matrix and return.
Parameters
Parameters
...
@@ -1019,20 +1120,23 @@ def from_scipy(sp_mat,
...
@@ -1019,20 +1120,23 @@ def from_scipy(sp_mat,
num_rows
=
sp_mat
.
shape
[
0
]
num_rows
=
sp_mat
.
shape
[
0
]
num_cols
=
sp_mat
.
shape
[
1
]
num_cols
=
sp_mat
.
shape
[
1
]
if
num_rows
!=
num_cols
:
if
num_rows
!=
num_cols
:
raise
DGLError
(
'Expect the number of rows to be the same as the number of columns for '
raise
DGLError
(
'sp_mat, got {:d} and {:d}.'
.
format
(
num_rows
,
num_cols
))
"Expect the number of rows to be the same as the number of columns for "
"sp_mat, got {:d} and {:d}."
.
format
(
num_rows
,
num_cols
)
)
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
sp_mat
,
idtype
)
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
'_N'
,
'_E'
,
'_N'
,
urange
,
vrange
)
sp_mat
,
idtype
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
"_N"
,
"_E"
,
"_N"
,
urange
,
vrange
)
if
eweight_name
is
not
None
:
if
eweight_name
is
not
None
:
g
.
edata
[
eweight_name
]
=
F
.
tensor
(
sp_mat
.
data
)
g
.
edata
[
eweight_name
]
=
F
.
tensor
(
sp_mat
.
data
)
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
bipartite_from_scipy
(
sp_mat
,
utype
,
etype
,
vtype
,
def
bipartite_from_scipy
(
eweight_name
=
None
,
sp_mat
,
utype
,
etype
,
vtype
,
eweight_name
=
None
,
idtype
=
None
,
device
=
None
idtype
=
None
,
):
device
=
None
):
"""Create a uni-directional bipartite graph from a SciPy sparse matrix and return.
"""Create a uni-directional bipartite graph from a SciPy sparse matrix and return.
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
...
@@ -1116,18 +1220,25 @@ def bipartite_from_scipy(sp_mat,
...
@@ -1116,18 +1220,25 @@ def bipartite_from_scipy(sp_mat,
heterograph
heterograph
bipartite_from_networkx
bipartite_from_networkx
"""
"""
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
sp_mat
,
idtype
,
bipartite
=
True
)
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
sp_mat
,
idtype
,
bipartite
=
True
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
if
eweight_name
is
not
None
:
if
eweight_name
is
not
None
:
g
.
edata
[
eweight_name
]
=
F
.
tensor
(
sp_mat
.
data
)
g
.
edata
[
eweight_name
]
=
F
.
tensor
(
sp_mat
.
data
)
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
from_networkx
(
nx_graph
,
node_attrs
=
None
,
def
from_networkx
(
edge_attrs
=
None
,
nx_graph
,
edge_id_attr_name
=
None
,
node_attrs
=
None
,
idtype
=
None
,
edge_attrs
=
None
,
device
=
None
):
edge_id_attr_name
=
None
,
idtype
=
None
,
device
=
None
,
):
"""Create a graph from a NetworkX graph and return.
"""Create a graph from a NetworkX graph and return.
.. note::
.. note::
...
@@ -1221,27 +1332,38 @@ def from_networkx(nx_graph,
...
@@ -1221,27 +1332,38 @@ def from_networkx(nx_graph,
from_scipy
from_scipy
"""
"""
# Sanity check
# Sanity check
if
edge_id_attr_name
is
not
None
and
\
if
(
edge_id_attr_name
not
in
next
(
iter
(
nx_graph
.
edges
(
data
=
True
)))[
-
1
]:
edge_id_attr_name
is
not
None
raise
DGLError
(
'Failed to find the pre-specified edge IDs in the edge features of '
and
edge_id_attr_name
not
in
next
(
iter
(
nx_graph
.
edges
(
data
=
True
)))[
-
1
]
'the NetworkX graph with name {}'
.
format
(
edge_id_attr_name
))
):
raise
DGLError
(
if
not
nx_graph
.
is_directed
()
and
not
(
edge_id_attr_name
is
None
and
edge_attrs
is
None
):
"Failed to find the pre-specified edge IDs in the edge features of "
raise
DGLError
(
'Expect edge_id_attr_name and edge_attrs to be None when nx_graph is '
"the NetworkX graph with name {}"
.
format
(
edge_id_attr_name
)
'undirected, got {} and {}'
.
format
(
edge_id_attr_name
,
edge_attrs
))
)
if
not
nx_graph
.
is_directed
()
and
not
(
edge_id_attr_name
is
None
and
edge_attrs
is
None
):
raise
DGLError
(
"Expect edge_id_attr_name and edge_attrs to be None when nx_graph is "
"undirected, got {} and {}"
.
format
(
edge_id_attr_name
,
edge_attrs
)
)
# Relabel nodes using consecutive integers starting from 0
# Relabel nodes using consecutive integers starting from 0
nx_graph
=
nx
.
convert_node_labels_to_integers
(
nx_graph
,
ordering
=
'
sorted
'
)
nx_graph
=
nx
.
convert_node_labels_to_integers
(
nx_graph
,
ordering
=
"
sorted
"
)
if
not
nx_graph
.
is_directed
():
if
not
nx_graph
.
is_directed
():
nx_graph
=
nx_graph
.
to_directed
()
nx_graph
=
nx_graph
.
to_directed
()
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
nx_graph
,
idtype
,
edge_id_attr_name
=
edge_id_attr_name
)
nx_graph
,
idtype
,
edge_id_attr_name
=
edge_id_attr_name
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
'
_N
'
,
'
_E
'
,
'
_N
'
,
urange
,
vrange
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
"
_N
"
,
"
_E
"
,
"
_N
"
,
urange
,
vrange
)
# nx_graph.edges(data=True) returns src, dst, attr_dict
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id
=
nx_graph
.
number_of_edges
()
>
0
and
edge_id_attr_name
is
not
None
has_edge_id
=
(
nx_graph
.
number_of_edges
()
>
0
and
edge_id_attr_name
is
not
None
)
# handle features
# handle features
# copy attributes
# copy attributes
...
@@ -1250,6 +1372,7 @@ def from_networkx(nx_graph,
...
@@ -1250,6 +1372,7 @@ def from_networkx(nx_graph,
return
F
.
cat
([
F
.
unsqueeze
(
x
,
0
)
for
x
in
lst
],
dim
=
0
)
return
F
.
cat
([
F
.
unsqueeze
(
x
,
0
)
for
x
in
lst
],
dim
=
0
)
else
:
else
:
return
F
.
tensor
(
lst
)
return
F
.
tensor
(
lst
)
if
node_attrs
is
not
None
:
if
node_attrs
is
not
None
:
# mapping from feature name to a list of tensors to be concatenated
# mapping from feature name to a list of tensors to be concatenated
attr_dict
=
defaultdict
(
list
)
attr_dict
=
defaultdict
(
list
)
...
@@ -1269,9 +1392,11 @@ def from_networkx(nx_graph,
...
@@ -1269,9 +1392,11 @@ def from_networkx(nx_graph,
num_edges
=
g
.
number_of_edges
()
num_edges
=
g
.
number_of_edges
()
for
_
,
_
,
attrs
in
nx_graph
.
edges
(
data
=
True
):
for
_
,
_
,
attrs
in
nx_graph
.
edges
(
data
=
True
):
if
attrs
[
edge_id_attr_name
]
>=
num_edges
:
if
attrs
[
edge_id_attr_name
]
>=
num_edges
:
raise
DGLError
(
'Expect the pre-specified edge ids to be'
raise
DGLError
(
' smaller than the number of edges --'
"Expect the pre-specified edge ids to be"
' {}, got {}.'
.
format
(
num_edges
,
attrs
[
'id'
]))
" smaller than the number of edges --"
" {}, got {}."
.
format
(
num_edges
,
attrs
[
"id"
])
)
for
key
in
edge_attrs
:
for
key
in
edge_attrs
:
attr_dict
[
key
][
attrs
[
edge_id_attr_name
]]
=
attrs
[
key
]
attr_dict
[
key
][
attrs
[
edge_id_attr_name
]]
=
attrs
[
key
]
else
:
else
:
...
@@ -1283,17 +1408,26 @@ def from_networkx(nx_graph,
...
@@ -1283,17 +1408,26 @@ def from_networkx(nx_graph,
for
attr
in
edge_attrs
:
for
attr
in
edge_attrs
:
for
val
in
attr_dict
[
attr
]:
for
val
in
attr_dict
[
attr
]:
if
val
is
None
:
if
val
is
None
:
raise
DGLError
(
'Not all edges have attribute {}.'
.
format
(
attr
))
raise
DGLError
(
"Not all edges have attribute {}."
.
format
(
attr
)
)
g
.
edata
[
attr
]
=
F
.
copy_to
(
_batcher
(
attr_dict
[
attr
]),
g
.
device
)
g
.
edata
[
attr
]
=
F
.
copy_to
(
_batcher
(
attr_dict
[
attr
]),
g
.
device
)
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
bipartite_from_networkx
(
nx_graph
,
utype
,
etype
,
vtype
,
def
bipartite_from_networkx
(
u_attrs
=
None
,
e_attrs
=
None
,
v_attrs
=
None
,
nx_graph
,
edge_id_attr_name
=
None
,
utype
,
idtype
=
None
,
etype
,
device
=
None
):
vtype
,
u_attrs
=
None
,
e_attrs
=
None
,
v_attrs
=
None
,
edge_id_attr_name
=
None
,
idtype
=
None
,
device
=
None
,
):
"""Create a unidirectional bipartite graph from a NetworkX graph and return.
"""Create a unidirectional bipartite graph from a NetworkX graph and return.
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
The created graph will have two types of nodes ``utype`` and ``vtype`` as well as one
...
@@ -1403,42 +1537,58 @@ def bipartite_from_networkx(nx_graph,
...
@@ -1403,42 +1537,58 @@ def bipartite_from_networkx(nx_graph,
bipartite_from_scipy
bipartite_from_scipy
"""
"""
if
not
nx_graph
.
is_directed
():
if
not
nx_graph
.
is_directed
():
raise
DGLError
(
'Expect nx_graph to be a directed NetworkX graph.'
)
raise
DGLError
(
"Expect nx_graph to be a directed NetworkX graph."
)
if
edge_id_attr_name
is
not
None
and
\
if
(
not
edge_id_attr_name
in
next
(
iter
(
nx_graph
.
edges
(
data
=
True
)))[
-
1
]:
edge_id_attr_name
is
not
None
raise
DGLError
(
'Failed to find the pre-specified edge IDs in the edge features '
and
not
edge_id_attr_name
in
next
(
iter
(
nx_graph
.
edges
(
data
=
True
)))[
-
1
]
'of the NetworkX graph with name {}'
.
format
(
edge_id_attr_name
))
):
raise
DGLError
(
"Failed to find the pre-specified edge IDs in the edge features "
"of the NetworkX graph with name {}"
.
format
(
edge_id_attr_name
)
)
# Get the source and destination node sets
# Get the source and destination node sets
top_nodes
=
set
()
top_nodes
=
set
()
bottom_nodes
=
set
()
bottom_nodes
=
set
()
for
n
,
ndata
in
nx_graph
.
nodes
(
data
=
True
):
for
n
,
ndata
in
nx_graph
.
nodes
(
data
=
True
):
if
'bipartite'
not
in
ndata
:
if
"bipartite"
not
in
ndata
:
raise
DGLError
(
'Expect the node {} to have attribute bipartite'
.
format
(
n
))
raise
DGLError
(
if
ndata
[
'bipartite'
]
==
0
:
"Expect the node {} to have attribute bipartite"
.
format
(
n
)
)
if
ndata
[
"bipartite"
]
==
0
:
top_nodes
.
add
(
n
)
top_nodes
.
add
(
n
)
elif
ndata
[
'
bipartite
'
]
==
1
:
elif
ndata
[
"
bipartite
"
]
==
1
:
bottom_nodes
.
add
(
n
)
bottom_nodes
.
add
(
n
)
else
:
else
:
raise
ValueError
(
'Expect the bipartite attribute of the node {} to be 0 or 1, '
raise
ValueError
(
'got {}'
.
format
(
n
,
ndata
[
'bipartite'
]))
"Expect the bipartite attribute of the node {} to be 0 or 1, "
"got {}"
.
format
(
n
,
ndata
[
"bipartite"
])
)
# Separately relabel the source and destination nodes.
# Separately relabel the source and destination nodes.
top_nodes
=
sorted
(
top_nodes
)
top_nodes
=
sorted
(
top_nodes
)
bottom_nodes
=
sorted
(
bottom_nodes
)
bottom_nodes
=
sorted
(
bottom_nodes
)
top_map
=
{
n
:
i
for
i
,
n
in
enumerate
(
top_nodes
)}
top_map
=
{
n
:
i
for
i
,
n
in
enumerate
(
top_nodes
)}
bottom_map
=
{
n
:
i
for
i
,
n
in
enumerate
(
bottom_nodes
)}
bottom_map
=
{
n
:
i
for
i
,
n
in
enumerate
(
bottom_nodes
)}
# Get the node tensors and the number of nodes
# Get the node tensors and the number of nodes
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
(
sparse_fmt
,
arrays
),
urange
,
vrange
=
utils
.
graphdata2tensors
(
nx_graph
,
idtype
,
bipartite
=
True
,
nx_graph
,
idtype
,
bipartite
=
True
,
edge_id_attr_name
=
edge_id_attr_name
,
edge_id_attr_name
=
edge_id_attr_name
,
top_map
=
top_map
,
bottom_map
=
bottom_map
)
top_map
=
top_map
,
bottom_map
=
bottom_map
,
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
g
=
create_from_edges
(
sparse_fmt
,
arrays
,
utype
,
etype
,
vtype
,
urange
,
vrange
)
# nx_graph.edges(data=True) returns src, dst, attr_dict
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id
=
nx_graph
.
number_of_edges
()
>
0
and
edge_id_attr_name
is
not
None
has_edge_id
=
(
nx_graph
.
number_of_edges
()
>
0
and
edge_id_attr_name
is
not
None
)
# handle features
# handle features
# copy attributes
# copy attributes
...
@@ -1485,11 +1635,14 @@ def bipartite_from_networkx(nx_graph,
...
@@ -1485,11 +1635,14 @@ def bipartite_from_networkx(nx_graph,
for
attr
in
e_attrs
:
for
attr
in
e_attrs
:
for
val
in
attr_dict
[
attr
]:
for
val
in
attr_dict
[
attr
]:
if
val
is
None
:
if
val
is
None
:
raise
DGLError
(
'Not all edges have attribute {}.'
.
format
(
attr
))
raise
DGLError
(
"Not all edges have attribute {}."
.
format
(
attr
)
)
g
.
edata
[
attr
]
=
F
.
copy_to
(
_batcher
(
attr_dict
[
attr
]),
g
.
device
)
g
.
edata
[
attr
]
=
F
.
copy_to
(
_batcher
(
attr_dict
[
attr
]),
g
.
device
)
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
to_networkx
(
g
,
node_attrs
=
None
,
edge_attrs
=
None
):
def
to_networkx
(
g
,
node_attrs
=
None
,
edge_attrs
=
None
):
"""Convert a homogeneous graph to a NetworkX graph and return.
"""Convert a homogeneous graph to a NetworkX graph and return.
...
@@ -1537,9 +1690,11 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
...
@@ -1537,9 +1690,11 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})])
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})])
"""
"""
if
g
.
device
!=
F
.
cpu
():
if
g
.
device
!=
F
.
cpu
():
raise
DGLError
(
'Cannot convert a CUDA graph to networkx. Call g.cpu() first.'
)
raise
DGLError
(
"Cannot convert a CUDA graph to networkx. Call g.cpu() first."
)
if
not
g
.
is_homogeneous
:
if
not
g
.
is_homogeneous
:
raise
DGLError
(
'
dgl.to_networkx only supports homogeneous graphs.
'
)
raise
DGLError
(
"
dgl.to_networkx only supports homogeneous graphs.
"
)
src
,
dst
=
g
.
edges
()
src
,
dst
=
g
.
edges
()
src
=
F
.
asnumpy
(
src
)
src
=
F
.
asnumpy
(
src
)
dst
=
F
.
asnumpy
(
dst
)
dst
=
F
.
asnumpy
(
dst
)
...
@@ -1552,16 +1707,22 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
...
@@ -1552,16 +1707,22 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
if
node_attrs
is
not
None
:
if
node_attrs
is
not
None
:
for
nid
,
attr
in
nx_graph
.
nodes
(
data
=
True
):
for
nid
,
attr
in
nx_graph
.
nodes
(
data
=
True
):
feat_dict
=
g
.
_get_n_repr
(
0
,
nid
)
feat_dict
=
g
.
_get_n_repr
(
0
,
nid
)
attr
.
update
({
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
node_attrs
})
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
node_attrs
}
)
if
edge_attrs
is
not
None
:
if
edge_attrs
is
not
None
:
for
_
,
_
,
attr
in
nx_graph
.
edges
(
data
=
True
):
for
_
,
_
,
attr
in
nx_graph
.
edges
(
data
=
True
):
eid
=
attr
[
'
id
'
]
eid
=
attr
[
"
id
"
]
feat_dict
=
g
.
_get_e_repr
(
0
,
eid
)
feat_dict
=
g
.
_get_e_repr
(
0
,
eid
)
attr
.
update
({
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
edge_attrs
})
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
edge_attrs
}
)
return
nx_graph
return
nx_graph
DGLGraph
.
to_networkx
=
to_networkx
DGLGraph
.
to_networkx
=
to_networkx
def to_cugraph(g):
    """Convert a DGL graph to a :class:`cugraph.Graph` and return.

    The input must be a homogeneous graph residing on a CUDA device.
    The conversion is zero-copy: edge endpoint tensors are handed to
    ``cudf`` through the DLPack protocol without moving data.

    Parameters
    ----------
    g : DGLGraph
        A homogeneous DGL graph on a CUDA device.

    Returns
    -------
    cugraph.Graph
        A directed cugraph graph with the same edges as ``g``.

    Raises
    ------
    DGLError
        If ``g`` is not on a CUDA device or is not homogeneous.
    ModuleNotFoundError
        If ``cugraph``/``cudf`` cannot be imported.
    """
    if g.device.type != "cuda":
        # NOTE: a space separates the two sentences; without it the message
        # previously rendered as "...to cugraph.Call g.to('cuda') first."
        raise DGLError(
            f"Cannot convert a {g.device.type} graph to cugraph. "
            "Call g.to('cuda') first."
        )
    if not g.is_homogeneous:
        raise DGLError("dgl.to_cugraph only supports homogeneous graphs.")
    try:
        import cudf
        import cugraph
    except ModuleNotFoundError as e:
        # Chain the original failure so users can see which import broke.
        raise ModuleNotFoundError(
            "to_cugraph requires cugraph which could not be imported"
        ) from e

    src, dst = g.edges()
    # Zero-copy handoff of the endpoint tensors into cudf Series via DLPack.
    src_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(src))
    dst_ser = cudf.from_dlpack(F.zerocopy_to_dlpack(dst))
    cudf_data = cudf.DataFrame({"source": src_ser, "destination": dst_ser})

    g_cugraph = cugraph.Graph(directed=True)
    g_cugraph.from_cudf_edgelist(
        cudf_data, source="source", destination="destination"
    )
    return g_cugraph


DGLGraph.to_cugraph = to_cugraph
def
from_cugraph
(
cugraph_graph
):
def
from_cugraph
(
cugraph_graph
):
"""Create a graph from a :class:`cugraph.Graph` object.
"""Create a graph from a :class:`cugraph.Graph` object.
...
@@ -1660,21 +1827,29 @@ def from_cugraph(cugraph_graph):
...
@@ -1660,21 +1827,29 @@ def from_cugraph(cugraph_graph):
cugraph_graph
=
cugraph_graph
.
to_directed
()
cugraph_graph
=
cugraph_graph
.
to_directed
()
edges
=
cugraph_graph
.
edges
()
edges
=
cugraph_graph
.
edges
()
src_t
=
F
.
zerocopy_from_dlpack
(
edges
[
'
src
'
].
to_dlpack
())
src_t
=
F
.
zerocopy_from_dlpack
(
edges
[
"
src
"
].
to_dlpack
())
dst_t
=
F
.
zerocopy_from_dlpack
(
edges
[
'
dst
'
].
to_dlpack
())
dst_t
=
F
.
zerocopy_from_dlpack
(
edges
[
"
dst
"
].
to_dlpack
())
g
=
graph
((
src_t
,
dst_t
))
g
=
graph
((
src_t
,
dst_t
))
return
g
return
g
############################################################
############################################################
# Internal APIs
# Internal APIs
############################################################
############################################################
def
create_from_edges
(
sparse_fmt
,
arrays
,
utype
,
etype
,
vtype
,
def
create_from_edges
(
urange
,
vrange
,
sparse_fmt
,
row_sorted
=
False
,
arrays
,
col_sorted
=
False
):
utype
,
etype
,
vtype
,
urange
,
vrange
,
row_sorted
=
False
,
col_sorted
=
False
,
):
"""Internal function to create a graph from incident nodes with types.
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
utype could be equal to vtype
...
@@ -1713,16 +1888,30 @@ def create_from_edges(sparse_fmt, arrays,
...
@@ -1713,16 +1888,30 @@ def create_from_edges(sparse_fmt, arrays,
else
:
else
:
num_ntypes
=
2
num_ntypes
=
2
if
sparse_fmt
==
'
coo
'
:
if
sparse_fmt
==
"
coo
"
:
u
,
v
=
arrays
u
,
v
=
arrays
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
num_ntypes
,
urange
,
vrange
,
u
,
v
,
[
'coo'
,
'csr'
,
'csc'
],
num_ntypes
,
row_sorted
,
col_sorted
)
urange
,
else
:
# 'csr' or 'csc'
vrange
,
u
,
v
,
[
"coo"
,
"csr"
,
"csc"
],
row_sorted
,
col_sorted
,
)
else
:
# 'csr' or 'csc'
indptr
,
indices
,
eids
=
arrays
indptr
,
indices
,
eids
=
arrays
hgidx
=
heterograph_index
.
create_unitgraph_from_csr
(
hgidx
=
heterograph_index
.
create_unitgraph_from_csr
(
num_ntypes
,
urange
,
vrange
,
indptr
,
indices
,
eids
,
[
'coo'
,
'csr'
,
'csc'
],
num_ntypes
,
sparse_fmt
==
'csc'
)
urange
,
vrange
,
indptr
,
indices
,
eids
,
[
"coo"
,
"csr"
,
"csc"
],
sparse_fmt
==
"csc"
,
)
if
utype
==
vtype
:
if
utype
==
vtype
:
return
DGLGraph
(
hgidx
,
[
utype
],
[
etype
])
return
DGLGraph
(
hgidx
,
[
utype
],
[
etype
])
else
:
else
:
...
...
python/dgl/core.py
View file @
a566b60b
...
@@ -2,10 +2,8 @@
...
@@ -2,10 +2,8 @@
# pylint: disable=not-callable
# pylint: disable=not-callable
import
numpy
as
np
import
numpy
as
np
from
.
import
backend
as
F
from
.
import
backend
as
F
,
function
as
fn
,
ops
from
.
import
function
as
fn
from
.base
import
ALL
,
dgl_warning
,
DGLError
,
EID
,
is_all
,
NID
from
.
import
ops
from
.base
import
ALL
,
EID
,
NID
,
DGLError
,
dgl_warning
,
is_all
from
.frame
import
Frame
from
.frame
import
Frame
from
.udf
import
EdgeBatch
,
NodeBatch
from
.udf
import
EdgeBatch
,
NodeBatch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment