OpenDAS / dgl / Commits

Commit d1827488 (Unverified)
authored Feb 20, 2023 by Hongzhi (Steve), Chen
committed by GitHub on Feb 20, 2023

autofix2 (#5333)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

parent 3e5137fe

Showing 11 changed files with 1129 additions and 560 deletions (+1129 -560)
python/dgl/dataloading/base.py              +125  -47
python/dgl/dataloading/dist_dataloader.py   +107  -53
python/dgl/dataloading/labor_sampler.py       +4   -6
python/dgl/dataloading/neighbor_sampler.py   +36  -14
python/dgl/dataloading/shadow.py             +26   -8
python/dgl/function/message.py               +20  -10
python/dgl/generators.py                      +1   -2
python/dgl/geometry/capi.py                   +1   -2
python/dgl/graph_index.py                     +2   -3
python/dgl/heterograph.py                   +805 -412
python/dgl/heterograph_index.py               +2   -3
python/dgl/dataloading/base.py
"""Base classes and functionalities for dataloaders"""
from
collections.abc
import
Mapping
import
inspect
from
..base
import
NID
,
EID
from
..convert
import
heterograph
from
collections.abc
import
Mapping
from
..
import
backend
as
F
from
..transforms
import
compact_graphs
from
..base
import
EID
,
NID
from
..convert
import
heterograph
from
..frame
import
LazyFeature
from
..utils
import
recursive_apply
,
context_of
from
..transforms
import
compact_graphs
from
..utils
import
context_of
,
recursive_apply
def
_set_lazy_features
(
x
,
xdata
,
feature_names
):
if
feature_names
is
None
:
...
...
@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
     for type_, names in feature_names.items():
         x[type_].data.update({k: LazyFeature(k) for k in names})

+
 def set_node_lazy_features(g, feature_names):
     """Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
...
@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
     """
     return _set_lazy_features(g.nodes, g.ndata, feature_names)

+
 def set_edge_lazy_features(g, feature_names):
     """Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
...
@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
     """
     return _set_lazy_features(g.edges, g.edata, feature_names)

+
 def set_src_lazy_features(g, feature_names):
     """Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
...
@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
     """
     return _set_lazy_features(g.srcnodes, g.srcdata, feature_names)

+
 def set_dst_lazy_features(g, feature_names):
     """Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
...
@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
     """
     return _set_lazy_features(g.dstnodes, g.dstdata, feature_names)
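The four helpers above all funnel into _set_lazy_features. A minimal usage sketch, not part of this commit; the subgraph and the feature names are assumed:

    # Hedged sketch: mark "feat" and "weight" for prefetching on a sampled
    # subgraph instead of slicing them eagerly.
    from dgl.dataloading.base import (
        set_edge_lazy_features,
        set_node_lazy_features,
    )

    def mark_for_prefetch(subg):
        set_node_lazy_features(subg, ["feat"])    # ndata["feat"] becomes a LazyFeature
        set_edge_lazy_features(subg, ["weight"])  # edata["weight"] becomes a LazyFeature
        return subg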
+
 class Sampler(object):
     """Base class for graph samplers.
...
@@ -171,6 +178,7 @@ class Sampler(object):
         def sample(self, g, indices):
             return g.subgraph(indices)
     """
+
     def sample(self, g, indices):
         """Abstract sample method.
...
@@ -183,6 +191,7 @@ class Sampler(object):
         """
         raise NotImplementedError

+
 class BlockSampler(Sampler):
     """Base class for sampling mini-batches in the form of Message-passing
     Flow Graphs (MFGs).
...
@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
         The device of the output subgraphs or MFGs. Default is the same as the
         minibatch of seed nodes.
     """
-    def __init__(self, prefetch_node_feats=None, prefetch_labels=None,
-                 prefetch_edge_feats=None, output_device=None):
+
+    def __init__(
+        self,
+        prefetch_node_feats=None,
+        prefetch_labels=None,
+        prefetch_edge_feats=None,
+        output_device=None,
+    ):
         super().__init__()
         self.prefetch_node_feats = prefetch_node_feats or []
         self.prefetch_labels = prefetch_labels or []
...
@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
             set_edge_lazy_features(block, self.prefetch_edge_feats)
         return input_nodes, output_nodes, blocks

-    def sample(self, g, seed_nodes, exclude_eids=None):  # pylint: disable=arguments-differ
+    def sample(
+        self, g, seed_nodes, exclude_eids=None
+    ):  # pylint: disable=arguments-differ
         """Sample a list of blocks from the given seed nodes."""
         result = self.sample_blocks(g, seed_nodes, exclude_eids=exclude_eids)
         return self.assign_lazy_features(result)
...
@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
         eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
         exclude_eids = {
             k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
-            for k, v in eids.items()}
+            for k, v in eids.items()
+        }
     else:
         exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
     return exclude_eids


 def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
     exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
     reverse_etype_map = {
         g.to_canonical_etype(k): g.to_canonical_etype(v)
-        for k, v in reverse_etype_map.items()}
-    exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()})
+        for k, v in reverse_etype_map.items()
+    }
+    exclude_eids.update(
+        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
+    )
     return exclude_eids
 def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
     if exclude_mode is None:
         return None
     elif callable(exclude_mode):
         return exclude_mode(eids)
-    elif F.is_tensor(exclude_mode) or (
-            isinstance(exclude_mode, Mapping) and
-            all(F.is_tensor(v) for v in exclude_mode.values())):
+    elif F.is_tensor(exclude_mode) or (
+        isinstance(exclude_mode, Mapping)
+        and all(F.is_tensor(v) for v in exclude_mode.values())
+    ):
         return exclude_mode
-    elif exclude_mode == 'self':
+    elif exclude_mode == "self":
         return eids
-    elif exclude_mode == 'reverse_id':
-        return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map'])
-    elif exclude_mode == 'reverse_types':
-        return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map'])
+    elif exclude_mode == "reverse_id":
+        return _find_exclude_eids_with_reverse_id(
+            g, eids, kwargs["reverse_eid_map"]
+        )
+    elif exclude_mode == "reverse_types":
+        return _find_exclude_eids_with_reverse_types(
+            g, eids, kwargs["reverse_etype_map"]
+        )
     else:
-        raise ValueError('unsupported mode {}'.format(exclude_mode))
+        raise ValueError("unsupported mode {}".format(exclude_mode))
-def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=None,
-                      output_device=None):
+def find_exclude_eids(
+    g,
+    seed_edges,
+    exclude,
+    reverse_eids=None,
+    reverse_etypes=None,
+    output_device=None,
+):
     """Find all edge IDs to exclude according to :attr:`exclude_mode`.

     Parameters
...
@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
         exclude,
         seed_edges,
         reverse_eid_map=reverse_eids,
-        reverse_etype_map=reverse_etypes)
+        reverse_etype_map=reverse_etypes,
+    )
     if exclude_eids is not None and output_device is not None:
-        exclude_eids = recursive_apply(exclude_eids, lambda x: F.copy_to(x, output_device))
+        exclude_eids = recursive_apply(
+            exclude_eids, lambda x: F.copy_to(x, output_device)
+        )
     return exclude_eids
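To make the exclusion modes concrete, a hedged sketch with made-up tensors: under exclude="reverse_id", each seed edge is excluded together with the reverse edge recorded in reverse_eids.

    # Hedged sketch; g is an assumed homogeneous graph with 4 edges, where
    # edges 0 <-> 1 and 2 <-> 3 are mutual reverses.
    import torch

    seed_edges = torch.tensor([0, 2])
    reverse_eids = torch.tensor([1, 0, 3, 2])
    excluded = find_exclude_eids(g, seed_edges, "reverse_id", reverse_eids=reverse_eids)
    # excluded == tensor([0, 2, 1, 3]): the seeds plus their reverses.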
+
 class EdgePredictionSampler(Sampler):
     """Sampler class that wraps an existing sampler for node classification into another
     one for edge classification or link prediction.
...
@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
     --------
     as_edge_prediction_sampler
     """
-    def __init__(self, sampler, exclude=None, reverse_eids=None, reverse_etypes=None,
-                 negative_sampler=None, prefetch_labels=None):
+
+    def __init__(
+        self,
+        sampler,
+        exclude=None,
+        reverse_eids=None,
+        reverse_etypes=None,
+        negative_sampler=None,
+        prefetch_labels=None,
+    ):
         super().__init__()
         # Check if the sampler's sample method has an optional third argument.
         argspec = inspect.getfullargspec(sampler.sample)
-        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
+        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
             raise TypeError(
                 "This sampler does not support edge or link prediction; please add an"
-                "optional third argument for edge IDs to exclude in its sample() method.")
+                "optional third argument for edge IDs to exclude in its sample() method."
+            )
         self.reverse_eids = reverse_eids
         self.reverse_etypes = reverse_etypes
         self.exclude = exclude
...
@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
     def _build_neg_graph(self, g, seed_edges):
         neg_srcdst = self.negative_sampler(g, seed_edges)
         if not isinstance(neg_srcdst, Mapping):
-            assert len(g.canonical_etypes) == 1, \
-                'graph has multiple or no edge types; ' \
-                'please return a dict in negative sampler.'
+            assert len(g.canonical_etypes) == 1, (
+                "graph has multiple or no edge types; "
+                "please return a dict in negative sampler."
+            )
             neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}

         dtype = F.dtype(list(neg_srcdst.values())[0][0])
         ctx = context_of(seed_edges) if seed_edges is not None else g.device
         neg_edges = {
-            etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx=ctx),
-                                          F.copy_to(F.tensor([], dtype), ctx=ctx)))
-            for etype in g.canonical_etypes}
+            etype: neg_srcdst.get(
+                etype,
+                (
+                    F.copy_to(F.tensor([], dtype), ctx=ctx),
+                    F.copy_to(F.tensor([], dtype), ctx=ctx),
+                ),
+            )
+            for etype in g.canonical_etypes
+        }
         neg_pair_graph = heterograph(
-            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes})
+            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
+        )
         return neg_pair_graph

     def assign_lazy_features(self, result):
...
@@ -390,7 +445,7 @@ class EdgePredictionSampler(Sampler):
         # In-place updates
         return result
-    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
+    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
         """Samples a list of blocks, as well as a subgraph containing the sampled
         edges from the original graph.
...
@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
         negative pairs as edges.
         """
         if isinstance(seed_edges, Mapping):
-            seed_edges = {g.to_canonical_etype(k): v for k, v in seed_edges.items()}
+            seed_edges = {
+                g.to_canonical_etype(k): v for k, v in seed_edges.items()
+            }
         exclude = self.exclude
         pair_graph = g.edge_subgraph(
-            seed_edges, relabel_nodes=False, output_device=self.output_device)
+            seed_edges, relabel_nodes=False, output_device=self.output_device
+        )
         eids = pair_graph.edata[EID]

         if self.negative_sampler is not None:
...
@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
         seed_nodes = pair_graph.ndata[NID]

         exclude_eids = find_exclude_eids(
-            g, seed_edges, exclude, self.reverse_eids, self.reverse_etypes,
-            self.output_device)
+            g,
+            seed_edges,
+            exclude,
+            self.reverse_eids,
+            self.reverse_etypes,
+            self.output_device,
+        )

-        input_nodes, _, blocks = self.sampler.sample(g, seed_nodes, exclude_eids)
+        input_nodes, _, blocks = self.sampler.sample(
+            g, seed_nodes, exclude_eids
+        )

         if self.negative_sampler is None:
             return self.assign_lazy_features((input_nodes, pair_graph, blocks))
         else:
-            return self.assign_lazy_features((input_nodes, pair_graph, neg_graph, blocks))
+            return self.assign_lazy_features(
+                (input_nodes, pair_graph, neg_graph, blocks)
+            )
 def as_edge_prediction_sampler(
-        sampler, exclude=None, reverse_eids=None, reverse_etypes=None,
-        negative_sampler=None, prefetch_labels=None):
+    sampler,
+    exclude=None,
+    reverse_eids=None,
+    reverse_etypes=None,
+    negative_sampler=None,
+    prefetch_labels=None,
+):
     """Create an edge-wise sampler from a node-wise sampler.

     For each batch of edges, the sampler applies the provided node-wise sampler to
...
@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
     ...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
     """
     return EdgePredictionSampler(
-        sampler, exclude=exclude, reverse_eids=reverse_eids, reverse_etypes=reverse_etypes,
-        negative_sampler=negative_sampler, prefetch_labels=prefetch_labels)
+        sampler,
+        exclude=exclude,
+        reverse_eids=reverse_eids,
+        reverse_etypes=reverse_etypes,
+        negative_sampler=negative_sampler,
+        prefetch_labels=prefetch_labels,
+    )
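Tying the wrapper to the training-loop shape shown in the docstring above, a hedged end-to-end sketch; g, train_eids, reverse_eids, and train_on are assumptions, not part of this commit:

    # Hedged sketch: link prediction with reverse-edge exclusion and
    # uniform negative sampling.
    import dgl

    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.NeighborSampler([15, 10]),
        exclude="reverse_id",
        reverse_eids=reverse_eids,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    )
    dataloader = dgl.dataloading.DataLoader(
        g, train_eids, sampler, batch_size=1024, shuffle=True
    )
    for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
        train_on(input_nodes, pair_graph, neg_pair_graph, blocks)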
python/dgl/dataloading/dist_dataloader.py
"""Distributed dataloaders.
"""
import
inspect
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
collections.abc
import
Mapping
from
abc
import
ABC
,
abstractproperty
,
abstractmethod
from
..
import
transforms
from
..base
import
NID
,
EID
from
..
import
backend
as
F
from
..
import
utils
from
..
import
backend
as
F
,
transforms
,
utils
from
..base
import
EID
,
NID
from
..convert
import
heterograph
from
..distributed
import
DistDataLoader
...
...
@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
         eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
         exclude_eids = {
             k: F.cat([v, F.gather_row(reverse_eid_map[k], v)], 0)
-            for k, v in eids.items()}
+            for k, v in eids.items()
+        }
     else:
         exclude_eids = F.cat([eids, F.gather_row(reverse_eid_map, eids)], 0)
     return exclude_eids


 def _find_exclude_eids_with_reverse_types(g, eids, reverse_etype_map):
     exclude_eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
     reverse_etype_map = {
         g.to_canonical_etype(k): g.to_canonical_etype(v)
-        for k, v in reverse_etype_map.items()}
-    exclude_eids.update({reverse_etype_map[k]: v for k, v in exclude_eids.items()})
+        for k, v in reverse_etype_map.items()
+    }
+    exclude_eids.update(
+        {reverse_etype_map[k]: v for k, v in exclude_eids.items()}
+    )
     return exclude_eids
 def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
     """Find all edge IDs to exclude according to :attr:`exclude_mode`.
...
@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
     """
     if exclude_mode is None:
         return None
-    elif exclude_mode == 'self':
+    elif exclude_mode == "self":
         return eids
-    elif exclude_mode == 'reverse_id':
-        return _find_exclude_eids_with_reverse_id(g, eids, kwargs['reverse_eid_map'])
-    elif exclude_mode == 'reverse_types':
-        return _find_exclude_eids_with_reverse_types(g, eids, kwargs['reverse_etype_map'])
+    elif exclude_mode == "reverse_id":
+        return _find_exclude_eids_with_reverse_id(
+            g, eids, kwargs["reverse_eid_map"]
+        )
+    elif exclude_mode == "reverse_types":
+        return _find_exclude_eids_with_reverse_types(
+            g, eids, kwargs["reverse_etype_map"]
+        )
     else:
-        raise ValueError('unsupported mode {}'.format(exclude_mode))
+        raise ValueError("unsupported mode {}".format(exclude_mode))
 class Collator(ABC):
...
@@ -100,6 +109,7 @@ class Collator(ABC):
     :ref:`User Guide Section 6 <guide-minibatch>` and
     :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
     """
+
     @abstractproperty
     def dataset(self):
         """Returns the dataset object of the collator."""
...
@@ -122,6 +132,7 @@ class Collator(ABC):
         """
         raise NotImplementedError

+
 class NodeCollator(Collator):
     """DGL collator to combine nodes and their computation dependencies within a minibatch for
     training node classification or regression on a single graph with neighborhood sampling.
...
@@ -155,14 +166,16 @@ class NodeCollator(Collator):
     :ref:`User Guide Section 6 <guide-minibatch>` and
     :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
     """
+
     def __init__(self, g, nids, graph_sampler):
         self.g = g
         if not isinstance(nids, Mapping):
-            assert len(g.ntypes) == 1, \
-                "nids should be a dict of node type and ids for graph with multiple node types"
+            assert (
+                len(g.ntypes) == 1
+            ), "nids should be a dict of node type and ids for graph with multiple node types"
         self.graph_sampler = graph_sampler

-        self.nids = utils.prepare_tensor_or_dict(g, nids, 'nids')
+        self.nids = utils.prepare_tensor_or_dict(g, nids, "nids")
         self._dataset = utils.maybe_flatten_dict(self.nids)

     @property
...
@@ -197,12 +210,15 @@ class NodeCollator(Collator):
         if isinstance(items[0], tuple):
             # returns a list of pairs: group them by node types into a dict
             items = utils.group_as_dict(items)
-        items = utils.prepare_tensor_or_dict(self.g, items, 'items')
+        items = utils.prepare_tensor_or_dict(self.g, items, "items")

-        input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(self.g, items)
+        input_nodes, output_nodes, blocks = self.graph_sampler.sample_blocks(
+            self.g, items
+        )

         return input_nodes, output_nodes, blocks
 class EdgeCollator(Collator):
     """DGL collator to combine edges and their computation dependencies within a minibatch for
     training edge classification, edge regression, or link prediction on a single graph
...
@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
     :ref:`User Guide Section 6 <guide-minibatch>` and
     :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
     """
-    def __init__(self, g, eids, graph_sampler, g_sampling=None, exclude=None,
-                 reverse_eids=None, reverse_etypes=None, negative_sampler=None):
+
+    def __init__(
+        self,
+        g,
+        eids,
+        graph_sampler,
+        g_sampling=None,
+        exclude=None,
+        reverse_eids=None,
+        reverse_etypes=None,
+        negative_sampler=None,
+    ):
         self.g = g
         if not isinstance(eids, Mapping):
-            assert len(g.etypes) == 1, \
-                "eids should be a dict of etype and ids for graph with multiple etypes"
+            assert (
+                len(g.etypes) == 1
+            ), "eids should be a dict of etype and ids for graph with multiple etypes"
         self.graph_sampler = graph_sampler

         # One may wish to iterate over the edges in one graph while perform sampling in
...
@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
         self.reverse_etypes = reverse_etypes
         self.negative_sampler = negative_sampler

-        self.eids = utils.prepare_tensor_or_dict(g, eids, 'eids')
+        self.eids = utils.prepare_tensor_or_dict(g, eids, "eids")
         self._dataset = utils.maybe_flatten_dict(self.eids)

     @property
...
@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
         if isinstance(items[0], tuple):
             # returns a list of pairs: group them by node types into a dict
             items = utils.group_as_dict(items)
-        items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items')
+        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")

         pair_graph = self.g.edge_subgraph(items)
         seed_nodes = pair_graph.ndata[NID]
...
@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
             self.exclude,
             items,
             reverse_eid_map=self.reverse_eids,
-            reverse_etype_map=self.reverse_etypes)
+            reverse_etype_map=self.reverse_etypes,
+        )
         input_nodes, _, blocks = self.graph_sampler.sample_blocks(
-            self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
+            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
+        )

         return input_nodes, pair_graph, blocks
...
@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
         if isinstance(items[0], tuple):
             # returns a list of pairs: group them by node types into a dict
             items = utils.group_as_dict(items)
-        items = utils.prepare_tensor_or_dict(self.g_sampling, items, 'items')
+        items = utils.prepare_tensor_or_dict(self.g_sampling, items, "items")

         pair_graph = self.g.edge_subgraph(items, relabel_nodes=False)
         induced_edges = pair_graph.edata[EID]

         neg_srcdst = self.negative_sampler(self.g, items)
         if not isinstance(neg_srcdst, Mapping):
-            assert len(self.g.etypes) == 1, \
-                'graph has multiple or no edge types; ' \
-                'please return a dict in negative sampler.'
+            assert len(self.g.etypes) == 1, (
+                "graph has multiple or no edge types; "
+                "please return a dict in negative sampler."
+            )
             neg_srcdst = {self.g.canonical_etypes[0]: neg_srcdst}
         # Get dtype from a tuple of tensors
         dtype = F.dtype(list(neg_srcdst.values())[0][0])
         ctx = F.context(pair_graph)
         neg_edges = {
-            etype: neg_srcdst.get(etype, (F.copy_to(F.tensor([], dtype), ctx),
-                                          F.copy_to(F.tensor([], dtype), ctx)))
-            for etype in self.g.canonical_etypes}
+            etype: neg_srcdst.get(
+                etype,
+                (
+                    F.copy_to(F.tensor([], dtype), ctx),
+                    F.copy_to(F.tensor([], dtype), ctx),
+                ),
+            )
+            for etype in self.g.canonical_etypes
+        }
         neg_pair_graph = heterograph(
-            neg_edges, {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes})
+            neg_edges,
+            {ntype: self.g.number_of_nodes(ntype) for ntype in self.g.ntypes},
+        )

-        pair_graph, neg_pair_graph = transforms.compact_graphs([pair_graph, neg_pair_graph])
+        pair_graph, neg_pair_graph = transforms.compact_graphs(
+            [pair_graph, neg_pair_graph]
+        )
         pair_graph.edata[EID] = induced_edges

         seed_nodes = pair_graph.ndata[NID]
...
@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
             self.exclude,
             items,
             reverse_eid_map=self.reverse_eids,
-            reverse_etype_map=self.reverse_etypes)
+            reverse_etype_map=self.reverse_etypes,
+        )
         input_nodes, _, blocks = self.graph_sampler.sample_blocks(
-            self.g_sampling, seed_nodes, exclude_eids=exclude_eids)
+            self.g_sampling, seed_nodes, exclude_eids=exclude_eids
+        )

         return input_nodes, pair_graph, neg_pair_graph, blocks
...
@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
 def _remove_kwargs_dist(kwargs):
-    if 'num_workers' in kwargs:
-        del kwargs['num_workers']
-    if 'pin_memory' in kwargs:
-        del kwargs['pin_memory']
-        print('Distributed DataLoaders do not support pin_memory.')
+    if "num_workers" in kwargs:
+        del kwargs["num_workers"]
+    if "pin_memory" in kwargs:
+        del kwargs["pin_memory"]
+        print("Distributed DataLoaders do not support pin_memory.")
     return kwargs
 class DistNodeDataLoader(DistDataLoader):
     """Sampled graph data loader over nodes for distributed graph storage.
...
@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
     --------
     dgl.dataloading.DataLoader
     """
+
     def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
...
@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
             dataloader_kwargs[k] = v

         if device is None:
             # for the distributed case default to the CPU
-            device = 'cpu'
-        assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
+            device = "cpu"
+        assert (
+            device == "cpu"
+        ), "Only cpu is supported in the case of a DistGraph."
         # Distributed DataLoader currently does not support heterogeneous graphs
         # and does not copy features. Fallback to normal solution
         self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
         _remove_kwargs_dist(dataloader_kwargs)
-        super().__init__(self.collator.dataset, collate_fn=self.collator.collate,
-                         **dataloader_kwargs)
+        super().__init__(
+            self.collator.dataset,
+            collate_fn=self.collator.collate,
+            **dataloader_kwargs
+        )
         self.device = device
 class DistEdgeDataLoader(DistDataLoader):
     """Sampled graph data loader over edges for distributed graph storage.
...
@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
     --------
     dgl.dataloading.DataLoader
     """
+
     def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
...
@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
         if device is None:
             # for the distributed case default to the CPU
-            device = 'cpu'
-        assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
+            device = "cpu"
+        assert (
+            device == "cpu"
+        ), "Only cpu is supported in the case of a DistGraph."
         # Distributed DataLoader currently does not support heterogeneous graphs
         # and does not copy features. Fallback to normal solution
         self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
         _remove_kwargs_dist(dataloader_kwargs)
-        super().__init__(self.collator.dataset, collate_fn=self.collator.collate,
-                         **dataloader_kwargs)
+        super().__init__(
+            self.collator.dataset,
+            collate_fn=self.collator.collate,
+            **dataloader_kwargs
+        )
         self.device = device
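A hedged sketch of the node-side loader defined above; the partitioned DistGraph g and the ID tensor train_nids are assumed to come from the usual dgl.distributed setup, and are not part of this commit:

    # Hedged sketch: distributed node-classification minibatches.
    import dgl

    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
    dataloader = dgl.dataloading.DistNodeDataLoader(
        g, train_nids, sampler, batch_size=1000, shuffle=True
    )
    for input_nodes, output_nodes, blocks in dataloader:
        pass  # forward/backward on the sampled blocks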
python/dgl/dataloading/labor_sampler.py
...
@@ -17,11 +17,11 @@
 #
 """Data loading components for labor sampling"""
-from ..base import NID, EID
+from .. import backend as F
+from ..base import EID, NID
+from ..random import choice
 from ..transforms import to_block
 from .base import BlockSampler
-from ..random import choice
-from .. import backend as F


 class LaborSampler(BlockSampler):
...
@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
             )
             block.edata[EID] = eid
             if len(g.canonical_etypes) > 1:
-                for etype, importance in zip(
-                    g.canonical_etypes, importances
-                ):
+                for etype, importance in zip(g.canonical_etypes, importances):
                     if importance.shape[0] == block.num_edges(etype):
                         block.edata["edge_weights"][etype] = importance
             elif importances[0].shape[0] == block.num_edges():
...
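For orientation, LaborSampler is constructed like NeighborSampler, with one fanout per layer; a hedged sketch (the graph g and seed IDs train_nids are assumed):

    # Hedged sketch: labor sampling with per-layer fanouts 5, 10, 15.
    import dgl

    sampler = dgl.dataloading.LaborSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g, train_nids, sampler, batch_size=1024, shuffle=True
    )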
python/dgl/dataloading/neighbor_sampler.py
"""Data loading components for neighbor sampling"""
from
..base
import
N
ID
,
E
ID
from
..base
import
E
ID
,
N
ID
from
..transforms
import
to_block
from
.base
import
BlockSampler
class
NeighborSampler
(
BlockSampler
):
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
...
...
@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def
__init__
(
self
,
fanouts
,
edge_dir
=
'in'
,
prob
=
None
,
mask
=
None
,
replace
=
False
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
):
super
().
__init__
(
prefetch_node_feats
=
prefetch_node_feats
,
prefetch_labels
=
prefetch_labels
,
prefetch_edge_feats
=
prefetch_edge_feats
,
output_device
=
output_device
)
def
__init__
(
self
,
fanouts
,
edge_dir
=
"in"
,
prob
=
None
,
mask
=
None
,
replace
=
False
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
):
super
().
__init__
(
prefetch_node_feats
=
prefetch_node_feats
,
prefetch_labels
=
prefetch_labels
,
prefetch_edge_feats
=
prefetch_edge_feats
,
output_device
=
output_device
,
)
self
.
fanouts
=
fanouts
self
.
edge_dir
=
edge_dir
if
mask
is
not
None
and
prob
is
not
None
:
raise
ValueError
(
'Mask and probability arguments are mutually exclusive. '
'Consider multiplying the probability with the mask '
'to achieve the same goal.'
)
"Mask and probability arguments are mutually exclusive. "
"Consider multiplying the probability with the mask "
"to achieve the same goal."
)
self
.
prob
=
prob
or
mask
self
.
replace
=
replace
...
...
@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
blocks
=
[]
for
fanout
in
reversed
(
self
.
fanouts
):
frontier
=
g
.
sample_neighbors
(
seed_nodes
,
fanout
,
edge_dir
=
self
.
edge_dir
,
prob
=
self
.
prob
,
replace
=
self
.
replace
,
output_device
=
self
.
output_device
,
exclude_edges
=
exclude_eids
)
seed_nodes
,
fanout
,
edge_dir
=
self
.
edge_dir
,
prob
=
self
.
prob
,
replace
=
self
.
replace
,
output_device
=
self
.
output_device
,
exclude_edges
=
exclude_eids
,
)
eid
=
frontier
.
edata
[
EID
]
block
=
to_block
(
frontier
,
seed_nodes
)
block
.
edata
[
EID
]
=
eid
...
...
@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
return
seed_nodes
,
output_nodes
,
blocks
MultiLayerNeighborSampler
=
NeighborSampler
class
MultiLayerFullNeighborSampler
(
NeighborSampler
):
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
...
...
@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def
__init__
(
self
,
num_layers
,
**
kwargs
):
super
().
__init__
([
-
1
]
*
num_layers
,
**
kwargs
)
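A hedged sketch of the samplers above in use (g and train_nids assumed); since sample_blocks prepends each new MFG, one block is produced per fanout entry:

    # Hedged sketch: 3-layer neighbor sampling.
    import dgl

    sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g, train_nids, sampler, batch_size=1024, shuffle=True
    )
    for input_nodes, output_nodes, blocks in dataloader:
        assert len(blocks) == 3  # one MFG per fanout entry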
python/dgl/dataloading/shadow.py
"""ShaDow-GNN subgraph samplers."""
from
..sampling.utils
import
EidExcluder
from
..
import
transforms
from
..base
import
NID
from
.base
import
set_node_lazy_features
,
set_edge_lazy_features
,
Sampler
from
..sampling.utils
import
EidExcluder
from
.base
import
Sampler
,
set_edge_lazy_features
,
set_node_lazy_features
class
ShaDowKHopSampler
(
Sampler
):
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
...
...
@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
def
__init__
(
self
,
fanouts
,
replace
=
False
,
prob
=
None
,
prefetch_node_feats
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
):
def
__init__
(
self
,
fanouts
,
replace
=
False
,
prob
=
None
,
prefetch_node_feats
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
):
super
().
__init__
()
self
.
fanouts
=
fanouts
self
.
replace
=
replace
...
...
@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
self
.
prefetch_edge_feats
=
prefetch_edge_feats
self
.
output_device
=
output_device
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
"""Sampling function.
Parameters
...
...
@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
output_nodes
=
seed_nodes
for
fanout
in
reversed
(
self
.
fanouts
):
frontier
=
g
.
sample_neighbors
(
seed_nodes
,
fanout
,
output_device
=
self
.
output_device
,
replace
=
self
.
replace
,
prob
=
self
.
prob
,
exclude_edges
=
exclude_eids
)
seed_nodes
,
fanout
,
output_device
=
self
.
output_device
,
replace
=
self
.
replace
,
prob
=
self
.
prob
,
exclude_edges
=
exclude_eids
,
)
block
=
transforms
.
to_block
(
frontier
,
seed_nodes
)
seed_nodes
=
block
.
srcdata
[
NID
]
subg
=
g
.
subgraph
(
seed_nodes
,
relabel_nodes
=
True
,
output_device
=
self
.
output_device
)
subg
=
g
.
subgraph
(
seed_nodes
,
relabel_nodes
=
True
,
output_device
=
self
.
output_device
)
if
exclude_eids
is
not
None
:
subg
=
EidExcluder
(
exclude_eids
)(
subg
)
...
...
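Unlike the MFG-based samplers, ShaDowKHopSampler yields one whole subgraph per minibatch; a hedged sketch (graph and seed-ID setup assumed):

    # Hedged sketch: ShaDow-style k-hop subgraph minibatches.
    import dgl

    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g, train_nids, sampler, batch_size=256, shuffle=True
    )
    for input_nodes, output_nodes, subgraph in dataloader:
        pass  # run the model on the sampled subgraph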
python/dgl/function/message.py
...
@@ -7,8 +7,7 @@ from itertools import product
 from .base import BuiltinFunction, TargetCode

-__all__ = ["copy_u", "copy_e", "BinaryMessageFunction",
-           "CopyMessageFunction"]
+__all__ = ["copy_u", "copy_e", "BinaryMessageFunction", "CopyMessageFunction"]


 class MessageFunction(BuiltinFunction):
...
@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
     --------
     u_mul_e
     """
+
     def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
         self.binary_op = binary_op
         self.lhs = lhs
...
@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
     --------
     copy_u
     """
+
     def __init__(self, target, in_field, out_field):
         self.target = target
         self.in_field = in_field
...
@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
     --------
     >>> import dgl
     >>> message_func = dgl.function.{}('h', 'h', 'm')
-    """.format(binary_op,
-               TargetCode.CODE2STR[_TARGET_MAP[lhs]],
-               TargetCode.CODE2STR[_TARGET_MAP[rhs]],
-               TargetCode.CODE2STR[_TARGET_MAP[lhs]],
-               TargetCode.CODE2STR[_TARGET_MAP[rhs]],
-               name)
+    """.format(
+        binary_op,
+        TargetCode.CODE2STR[_TARGET_MAP[lhs]],
+        TargetCode.CODE2STR[_TARGET_MAP[rhs]],
+        TargetCode.CODE2STR[_TARGET_MAP[lhs]],
+        TargetCode.CODE2STR[_TARGET_MAP[rhs]],
+        name,
+    )

     def func(lhs_field, rhs_field, out):
         return BinaryMessageFunction(
-            binary_op, _TARGET_MAP[lhs], _TARGET_MAP[rhs],
-            lhs_field, rhs_field, out)
+            binary_op,
+            _TARGET_MAP[lhs],
+            _TARGET_MAP[rhs],
+            lhs_field,
+            rhs_field,
+            out,
+        )

     func.__name__ = name
     func.__doc__ = docstring
     return func
...
@@ -177,4 +186,5 @@ def _register_builtin_message_func():
         setattr(sys.modules[__name__], func.__name__, func)
         __all__.append(func.__name__)

+
 _register_builtin_message_func()
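The generator above stamps out builtins such as u_mul_e; a hedged sketch of one of them in a message-passing call (a graph g with ndata["h"] and edata["w"] is assumed):

    # Hedged sketch: message m = h_src * w_edge, reduced by sum at the
    # destination nodes.
    import dgl.function as fn

    g.update_all(
        fn.u_mul_e("h", "w", "m"),  # generated binary message builtin
        fn.sum("m", "h_new"),       # built-in reduce function
    )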
python/dgl/generators.py
"""Module for various graph generator functions."""
from
.
import
backend
as
F
from
.
import
convert
,
random
from
.
import
backend
as
F
,
convert
,
random
__all__
=
[
"rand_graph"
,
"rand_bipartite"
]
...
...
python/dgl/geometry/capi.py
"""Python interfaces to DGL farthest point sampler."""
import
numpy
as
np
from
..
import
backend
as
F
from
..
import
ndarray
as
nd
from
..
import
backend
as
F
,
ndarray
as
nd
from
.._ffi.base
import
DGLError
from
.._ffi.function
import
_init_api
...
...
python/dgl/graph_index.py
...
@@ -5,11 +5,10 @@ import networkx as nx
 import numpy as np
 import scipy

-from . import backend as F
-from . import utils
+from . import backend as F, utils
 from ._ffi.function import _init_api
 from ._ffi.object import ObjectBase, register_object
-from .base import DGLError, dgl_warning
+from .base import dgl_warning, DGLError


 class BoolFlag(object):
...
python/dgl/heterograph.py
"""Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections.abc
import
Mapping
,
Iterable
from
contextlib
import
contextmanager
import
copy
import
numbers
import
itertools
import
numbers
# pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
contextlib
import
contextmanager
import
networkx
as
nx
import
numpy
as
np
from
.
import
backend
as
F
,
core
,
graph_index
,
heterograph_index
,
utils
from
._ffi.function
import
_init_api
from
.ops
import
segment
from
.base
import
ALL
,
SLICE_FULL
,
NTYPE
,
NID
,
ETYPE
,
EID
,
is_all
,
DGLError
,
dgl_warning
from
.
import
core
from
.
import
graph_index
from
.
import
heterograph_index
from
.
import
utils
from
.
import
backend
as
F
from
.base
import
(
ALL
,
dgl_warning
,
DGLError
,
EID
,
ETYPE
,
is_all
,
NID
,
NTYPE
,
SLICE_FULL
,
)
from
.frame
import
Frame
from
.view
import
HeteroNodeView
,
HeteroNodeDataView
,
HeteroEdgeView
,
HeteroEdgeDataView
from
.ops
import
segment
from
.view
import
(
HeteroEdgeDataView
,
HeteroEdgeView
,
HeteroNodeDataView
,
HeteroNodeView
,
)
__all__
=
[
"DGLGraph"
,
"combine_names"
]
__all__
=
[
'DGLGraph'
,
'combine_names'
]
 class DGLGraph(object):
     """Class for storing graph structure and node/edge feature data.
...
@@ -35,16 +50,19 @@ class DGLGraph(object):
     Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its
     usage.
     """
+
     is_block = False

     # pylint: disable=unused-argument, dangerous-default-value
-    def __init__(self,
-                 gidx=[],
-                 ntypes=['_N'],
-                 etypes=['_E'],
-                 node_frames=None,
-                 edge_frames=None,
-                 **deprecate_kwargs):
+    def __init__(
+        self,
+        gidx=[],
+        ntypes=["_N"],
+        etypes=["_E"],
+        node_frames=None,
+        edge_frames=None,
+        **deprecate_kwargs
+    ):
         """Internal constructor for creating a DGLGraph.

         Parameters
...
@@ -67,21 +85,42 @@ class DGLGraph(object):
         of edge type i. (default: None)
         """
         if isinstance(gidx, DGLGraph):
-            raise DGLError('The input is already a DGLGraph. No need to create it again.')
+            raise DGLError(
+                "The input is already a DGLGraph. No need to create it again."
+            )
         if not isinstance(gidx, heterograph_index.HeteroGraphIndex):
-            dgl_warning('Recommend creating graphs by `dgl.graph(data)`'
-                        ' instead of `dgl.DGLGraph(data)`.')
-            (sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(gidx)
-            if sparse_fmt == 'coo':
+            dgl_warning(
+                "Recommend creating graphs by `dgl.graph(data)`"
+                " instead of `dgl.DGLGraph(data)`."
+            )
+            (sparse_fmt, arrays), num_src, num_dst = utils.graphdata2tensors(
+                gidx
+            )
+            if sparse_fmt == "coo":
                 gidx = heterograph_index.create_unitgraph_from_coo(
-                    1, num_src, num_dst, arrays[0], arrays[1],
-                    ['coo', 'csr', 'csc'])
+                    1,
+                    num_src,
+                    num_dst,
+                    arrays[0],
+                    arrays[1],
+                    ["coo", "csr", "csc"],
+                )
             else:
                 gidx = heterograph_index.create_unitgraph_from_csr(
-                    1, num_src, num_dst, arrays[0], arrays[1], arrays[2],
-                    ['coo', 'csr', 'csc'], sparse_fmt == 'csc')
+                    1,
+                    num_src,
+                    num_dst,
+                    arrays[0],
+                    arrays[1],
+                    arrays[2],
+                    ["coo", "csr", "csc"],
+                    sparse_fmt == "csc",
+                )
         if len(deprecate_kwargs) != 0:
-            dgl_warning('Keyword arguments {} are deprecated in v0.5, and can be safely'
-                        ' removed in all cases.'.format(list(deprecate_kwargs.keys())))
+            dgl_warning(
+                "Keyword arguments {} are deprecated in v0.5, and can be safely"
+                " removed in all cases.".format(list(deprecate_kwargs.keys()))
+            )
         self._init(gidx, ntypes, etypes, node_frames, edge_frames)

     def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
...
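The dgl_warning in the hunk above points users to the public constructor; a hedged sketch of the recommended path:

    # Hedged sketch: build a graph from a COO pair instead of calling
    # DGLGraph(...) directly.
    import dgl
    import torch

    src = torch.tensor([0, 1, 2])
    dst = torch.tensor([1, 2, 3])
    g = dgl.graph((src, dst))  # 4 nodes, 3 edges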
@@ -94,39 +133,51 @@ class DGLGraph(object):
         # Handle node types
         if isinstance(ntypes, tuple):
             if len(ntypes) != 2:
-                errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format(
-                    ntypes)
+                errmsg = "Invalid input. Expect a pair (srctypes, dsttypes) but got {}".format(
+                    ntypes
+                )
                 raise TypeError(errmsg)
             if not self._graph.is_metagraph_unibipartite():
-                raise ValueError('Invalid input. The metagraph must be a uni-directional'
-                                 ' bipartite graph.')
+                raise ValueError(
+                    "Invalid input. The metagraph must be a uni-directional"
+                    " bipartite graph."
+                )
             self._ntypes = ntypes[0] + ntypes[1]
             self._srctypes_invmap = {t: i for i, t in enumerate(ntypes[0])}
-            self._dsttypes_invmap = {t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
+            self._dsttypes_invmap = {
+                t: i + len(ntypes[0]) for i, t in enumerate(ntypes[1])
+            }
             self._is_unibipartite = True
             if len(ntypes[0]) == 1 and len(ntypes[1]) == 1 and len(etypes) == 1:
-                self._canonical_etypes = [(ntypes[0][0], etypes[0], ntypes[1][0])]
+                self._canonical_etypes = [
+                    (ntypes[0][0], etypes[0], ntypes[1][0])
+                ]
         else:
             self._ntypes = ntypes
             if len(ntypes) == 1:
                 src_dst_map = None
             else:
-                src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
-            self._is_unibipartite = (src_dst_map is not None)
+                src_dst_map = find_src_dst_ntypes(
+                    self._ntypes, self._graph.metagraph
+                )
+            self._is_unibipartite = src_dst_map is not None
             if self._is_unibipartite:
                 self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
             else:
-                self._srctypes_invmap = {t: i for i, t in enumerate(self._ntypes)}
+                self._srctypes_invmap = {
+                    t: i for i, t in enumerate(self._ntypes)
+                }
                 self._dsttypes_invmap = self._srctypes_invmap

         # Handle edge types
         self._etypes = etypes
         if self._canonical_etypes is None:
-            if (len(etypes) == 1 and len(ntypes) == 1):
+            if len(etypes) == 1 and len(ntypes) == 1:
                 self._canonical_etypes = [(ntypes[0], etypes[0], ntypes[0])]
             else:
                 self._canonical_etypes = make_canonical_etypes(
-                    self._etypes, self._ntypes, self._graph.metagraph)
+                    self._etypes, self._ntypes, self._graph.metagraph
+                )

         # An internal map from etype to canonical etype tuple.
         # If two etypes have the same name, an empty tuple is stored instead to indicate
...
@@ -137,21 +188,29 @@ class DGLGraph(object):
                 self._etype2canonical[ety] = tuple()
             else:
                 self._etype2canonical[ety] = self._canonical_etypes[i]
-        self._etypes_invmap = {t: i for i, t in enumerate(self._canonical_etypes)}
+        self._etypes_invmap = {
+            t: i for i, t in enumerate(self._canonical_etypes)
+        }

         # node and edge frame
         if node_frames is None:
             node_frames = [None] * len(self._ntypes)
-        node_frames = [Frame(num_rows=self._graph.number_of_nodes(i))
-                       if frame is None else frame
-                       for i, frame in enumerate(node_frames)]
+        node_frames = [
+            Frame(num_rows=self._graph.number_of_nodes(i))
+            if frame is None
+            else frame
+            for i, frame in enumerate(node_frames)
+        ]
         self._node_frames = node_frames

         if edge_frames is None:
             edge_frames = [None] * len(self._etypes)
-        edge_frames = [Frame(num_rows=self._graph.number_of_edges(i))
-                       if frame is None else frame
-                       for i, frame in enumerate(edge_frames)]
+        edge_frames = [
+            Frame(num_rows=self._graph.number_of_edges(i))
+            if frame is None
+            else frame
+            for i, frame in enumerate(edge_frames)
+        ]
         self._edge_frames = edge_frames

     def __setstate__(self, state):
...
@@ -162,40 +221,60 @@ class DGLGraph(object):
             self.__dict__.update(state)
         elif isinstance(state, tuple) and len(state) == 5:
             # DGL == 0.4.3
-            dgl_warning("The object is pickled with DGL == 0.4.3. "
-                        "Some of the original attributes are ignored.")
+            dgl_warning(
+                "The object is pickled with DGL == 0.4.3. "
+                "Some of the original attributes are ignored."
+            )
             self._init(*state)
         elif isinstance(state, dict):
             # DGL <= 0.4.2
-            dgl_warning("The object is pickled with DGL <= 0.4.2. "
-                        "Some of the original attributes are ignored.")
-            self._init(state['_graph'], state['_ntypes'], state['_etypes'],
-                       state['_node_frames'], state['_edge_frames'])
+            dgl_warning(
+                "The object is pickled with DGL <= 0.4.2. "
+                "Some of the original attributes are ignored."
+            )
+            self._init(
+                state["_graph"],
+                state["_ntypes"],
+                state["_etypes"],
+                state["_node_frames"],
+                state["_edge_frames"],
+            )
         else:
             raise IOError("Unrecognized pickle format.")

     def __repr__(self):
         if len(self.ntypes) == 1 and len(self.etypes) == 1:
-            ret = ('Graph(num_nodes={node}, num_edges={edge},\n'
-                   '      ndata_schemes={ndata}\n'
-                   '      edata_schemes={edata})')
-            return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
-                              ndata=str(self.node_attr_schemes()),
-                              edata=str(self.edge_attr_schemes()))
+            ret = (
+                "Graph(num_nodes={node}, num_edges={edge},\n"
+                "      ndata_schemes={ndata}\n"
+                "      edata_schemes={edata})"
+            )
+            return ret.format(
+                node=self.number_of_nodes(),
+                edge=self.number_of_edges(),
+                ndata=str(self.node_attr_schemes()),
+                edata=str(self.edge_attr_schemes()),
+            )
         else:
-            ret = ('Graph(num_nodes={node},\n'
-                   '      num_edges={edge},\n'
-                   '      metagraph={meta})')
-            nnode_dict = {self.ntypes[i]: self._graph.number_of_nodes(i)
-                          for i in range(len(self.ntypes))}
-            nedge_dict = {self.canonical_etypes[i]: self._graph.number_of_edges(i)
-                          for i in range(len(self.etypes))}
+            ret = (
+                "Graph(num_nodes={node},\n"
+                "      num_edges={edge},\n"
+                "      metagraph={meta})"
+            )
+            nnode_dict = {
+                self.ntypes[i]: self._graph.number_of_nodes(i)
+                for i in range(len(self.ntypes))
+            }
+            nedge_dict = {
+                self.canonical_etypes[i]: self._graph.number_of_edges(i)
+                for i in range(len(self.etypes))
+            }
             meta = str(self.metagraph().edges(keys=True))
             return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
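For reference, the homogeneous branch of __repr__ above renders roughly as follows, a hedged sketch continuing the dgl.graph example:

    print(g)
    # Graph(num_nodes=4, num_edges=3,
    #       ndata_schemes={}
    #       edata_schemes={})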
     def __copy__(self):
         """Shallow copy implementation."""
-        #TODO(minjie): too many states in python; should clean up and lower to C
+        # TODO(minjie): too many states in python; should clean up and lower to C
         cls = type(self)
         obj = cls.__new__(cls)
         obj.__dict__.update(self.__dict__)
...
@@ -298,14 +377,16 @@ class DGLGraph(object):
         # TODO(xiangsx): block do not support add_nodes
         if ntype is None:
             if self._graph.number_of_ntypes() != 1:
-                raise DGLError('Node type name must be specified if there are more than one '
-                               'node types.')
+                raise DGLError(
+                    "Node type name must be specified if there are more than one "
+                    "node types."
+                )

         # nothing happen
         if num == 0:
             return

-        assert num > 0, 'Number of new nodes should be larger than one.'
+        assert num > 0, "Number of new nodes should be larger than one."
         ntid = self.get_ntype_id(ntype)
         # update graph idx
         metagraph = self._graph.metagraph
...
@@ -319,23 +400,32 @@ class DGLGraph(object):
         relation_graphs = []
         for c_etype in self.canonical_etypes:
             # src or dst == ntype, update the relation graph
-            if self.get_ntype_id(c_etype[0]) == ntid or self.get_ntype_id(c_etype[2]) == ntid:
-                u, v = self.edges(form='uv', order='eid', etype=c_etype)
+            if (
+                self.get_ntype_id(c_etype[0]) == ntid
+                or self.get_ntype_id(c_etype[2]) == ntid
+            ):
+                u, v = self.edges(form="uv", order="eid", etype=c_etype)
                 hgidx = heterograph_index.create_unitgraph_from_coo(
                     1 if c_etype[0] == c_etype[2] else 2,
-                    self.number_of_nodes(c_etype[0]) + \
-                        (num if self.get_ntype_id(c_etype[0]) == ntid else 0),
-                    self.number_of_nodes(c_etype[2]) + \
-                        (num if self.get_ntype_id(c_etype[2]) == ntid else 0),
+                    self.number_of_nodes(c_etype[0])
+                    + (num if self.get_ntype_id(c_etype[0]) == ntid else 0),
+                    self.number_of_nodes(c_etype[2])
+                    + (num if self.get_ntype_id(c_etype[2]) == ntid else 0),
                     u,
                     v,
-                    ['coo', 'csr', 'csc'])
+                    ["coo", "csr", "csc"],
+                )
                 relation_graphs.append(hgidx)
             else:
                 # do nothing
-                relation_graphs.append(self._graph.get_relation_graph(self.get_etype_id(c_etype)))
+                relation_graphs.append(
+                    self._graph.get_relation_graph(self.get_etype_id(c_etype))
+                )
         hgidx = heterograph_index.create_heterograph_from_relations(
-            metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64"))
+            metagraph,
+            relation_graphs,
+            utils.toindex(num_nodes_per_type, "int64"),
+        )
         self._graph = hgidx

         # update data frames
...
@@ -452,26 +542,33 @@ class DGLGraph(object):
             remove_edges
         """
         # TODO(xiangsx): block do not support add_edges
-        u = utils.prepare_tensor(self, u, 'u')
-        v = utils.prepare_tensor(self, v, 'v')
+        u = utils.prepare_tensor(self, u, "u")
+        v = utils.prepare_tensor(self, v, "v")
         if etype is None:
             if self._graph.number_of_etypes() != 1:
-                raise DGLError('Edge type name must be specified if there are more than one '
-                               'edge types.')
+                raise DGLError(
+                    "Edge type name must be specified if there are more than one "
+                    "edge types."
+                )

         # nothing changed
         if len(u) == 0 or len(v) == 0:
             return

-        assert len(u) == len(v) or len(u) == 1 or len(v) == 1, \
-            'The number of source nodes and the number of destination nodes should be same, ' \
-            'or either the number of source nodes or the number of destination nodes is 1.'
+        assert len(u) == len(v) or len(u) == 1 or len(v) == 1, (
+            "The number of source nodes and the number of destination nodes should be same, "
+            "or either the number of source nodes or the number of destination nodes is 1."
+        )

         if len(u) == 1 and len(v) > 1:
-            u = F.full_1d(len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u))
+            u = F.full_1d(
+                len(v), F.as_scalar(u), dtype=F.dtype(u), ctx=F.context(u)
+            )
         if len(v) == 1 and len(u) > 1:
-            v = F.full_1d(len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v))
+            v = F.full_1d(
+                len(u), F.as_scalar(v), dtype=F.dtype(v), ctx=F.context(v)
+            )

         u_type, e_type, v_type = self.to_canonical_etype(etype)
         # if end nodes of adding edges does not exists
...
@@ -501,22 +598,28 @@ class DGLGraph(object):
         for c_etype in self.canonical_etypes:
             # the target edge type
             if c_etype == (u_type, e_type, v_type):
-                old_u, old_v = self.edges(form='uv', order='eid', etype=c_etype)
+                old_u, old_v = self.edges(form="uv", order="eid", etype=c_etype)
                 hgidx = heterograph_index.create_unitgraph_from_coo(
                     1 if u_type == v_type else 2,
                     self.number_of_nodes(u_type),
                     self.number_of_nodes(v_type),
                     F.cat([old_u, u], dim=0),
                     F.cat([old_v, v], dim=0),
-                    ['coo', 'csr', 'csc'])
+                    ["coo", "csr", "csc"],
+                )
                 relation_graphs.append(hgidx)
             else:
                 # do nothing
                 # Note: node range change has been handled in add_nodes()
-                relation_graphs.append(self._graph.get_relation_graph(self.get_etype_id(c_etype)))
+                relation_graphs.append(
+                    self._graph.get_relation_graph(self.get_etype_id(c_etype))
+                )

         hgidx = heterograph_index.create_heterograph_from_relations(
-            metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64"))
+            metagraph,
+            relation_graphs,
+            utils.toindex(num_nodes_per_type, "int64"),
+        )
         self._graph = hgidx

         # handle data
...
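A hedged sketch of the in-place mutation this hunk reformats, continuing the example graph; feature frames for the new edges are padded by the code above:

    g.add_edges(torch.tensor([3, 3]), torch.tensor([0, 1]))
    assert g.num_edges() == 5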
@@ -607,15 +710,19 @@ class DGLGraph(object):
         # TODO(xiangsx): block do not support remove_edges
         if etype is None:
             if self._graph.number_of_etypes() != 1:
-                raise DGLError('Edge type name must be specified if there are more than one ' \
-                               'edge types.')
-        eids = utils.prepare_tensor(self, eids, 'u')
+                raise DGLError(
+                    "Edge type name must be specified if there are more than one "
+                    "edge types."
+                )
+        eids = utils.prepare_tensor(self, eids, "u")
         if len(eids) == 0:
             # no edge to delete
             return
-        assert self.number_of_edges(etype) > F.as_scalar(F.max(eids, dim=0)), \
-            'The input eid {} is out of the range [0:{})'.format(
-                F.as_scalar(F.max(eids, dim=0)), self.number_of_edges(etype))
+        assert self.number_of_edges(etype) > F.as_scalar(
+            F.max(eids, dim=0)
+        ), "The input eid {} is out of the range [0:{})".format(
+            F.as_scalar(F.max(eids, dim=0)), self.number_of_edges(etype)
+        )

         # edge_subgraph
         edges = {}
...
@@ -623,25 +730,36 @@ class DGLGraph(object):
         for c_etype in self.canonical_etypes:
             # the target edge type
             if c_etype == (u_type, e_type, v_type):
-                origin_eids = self.edges(form='eid', order='eid', etype=c_etype)
+                origin_eids = self.edges(form="eid", order="eid", etype=c_etype)
                 edges[c_etype] = utils.compensate(eids, origin_eids)
             else:
-                edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
+                edges[c_etype] = self.edges(
+                    form="eid", order="eid", etype=c_etype
+                )

         # If the graph is batched, update batch_num_edges
         batched = self._batch_num_edges is not None
         if batched:
             c_etype = (u_type, e_type, v_type)
-            one_hot_removed_edges = F.zeros((self.num_edges(c_etype),), F.float32, self.device)
-            one_hot_removed_edges = F.scatter_row(one_hot_removed_edges, eids,
-                                                  F.full_1d(len(eids), 1., F.float32, self.device))
+            one_hot_removed_edges = F.zeros(
+                (self.num_edges(c_etype),), F.float32, self.device
+            )
+            one_hot_removed_edges = F.scatter_row(
+                one_hot_removed_edges,
+                eids,
+                F.full_1d(len(eids), 1.0, F.float32, self.device),
+            )
             c_etype_batch_num_edges = self._batch_num_edges[c_etype]
             batch_num_removed_edges = segment.segment_reduce(
-                c_etype_batch_num_edges, one_hot_removed_edges, reducer='sum')
-            self._batch_num_edges[c_etype] = c_etype_batch_num_edges - \
-                F.astype(batch_num_removed_edges, F.int64)
-        sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids)
+                c_etype_batch_num_edges, one_hot_removed_edges, reducer="sum"
+            )
+            self._batch_num_edges[c_etype] = c_etype_batch_num_edges - F.astype(
+                batch_num_removed_edges, F.int64
+            )
+        sub_g = self.edge_subgraph(
+            edges, relabel_nodes=False, store_ids=store_ids
+        )
         self._graph = sub_g._graph
         self._node_frames = sub_g._node_frames
         self._edge_frames = sub_g._edge_frames
...
@@ -733,16 +851,20 @@ class DGLGraph(object):
         # TODO(xiangsx): block do not support remove_nodes
         if ntype is None:
             if self._graph.number_of_ntypes() != 1:
-                raise DGLError('Node type name must be specified if there are more than one ' \
-                               'node types.')
+                raise DGLError(
+                    "Node type name must be specified if there are more than one "
+                    "node types."
+                )

-        nids = utils.prepare_tensor(self, nids, 'u')
+        nids = utils.prepare_tensor(self, nids, "u")
         if len(nids) == 0:
             # no node to delete
             return
-        assert self.number_of_nodes(ntype) > F.as_scalar(F.max(nids, dim=0)), \
-            'The input nids {} is out of the range [0:{})'.format(
-                F.as_scalar(F.max(nids, dim=0)), self.number_of_nodes(ntype))
+        assert self.number_of_nodes(ntype) > F.as_scalar(
+            F.max(nids, dim=0)
+        ), "The input nids {} is out of the range [0:{})".format(
+            F.as_scalar(F.max(nids, dim=0)), self.number_of_nodes(ntype)
+        )

         ntid = self.get_ntype_id(ntype)
         nodes = {}
...
@@ -757,18 +879,28 @@ class DGLGraph(object):
         # If the graph is batched, update batch_num_nodes
         batched = self._batch_num_nodes is not None
         if batched:
-            one_hot_removed_nodes = F.zeros((self.num_nodes(target_ntype),),
-                                            F.float32, self.device)
-            one_hot_removed_nodes = F.scatter_row(one_hot_removed_nodes, nids,
-                                                  F.full_1d(len(nids), 1., F.float32, self.device))
+            one_hot_removed_nodes = F.zeros(
+                (self.num_nodes(target_ntype),), F.float32, self.device
+            )
+            one_hot_removed_nodes = F.scatter_row(
+                one_hot_removed_nodes,
+                nids,
+                F.full_1d(len(nids), 1.0, F.float32, self.device),
+            )
             c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype]
             batch_num_removed_nodes = segment.segment_reduce(
-                c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer='sum')
-            self._batch_num_nodes[target_ntype] = c_ntype_batch_num_nodes - \
-                F.astype(batch_num_removed_nodes, F.int64)
+                c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer="sum"
+            )
+            self._batch_num_nodes[target_ntype] = c_ntype_batch_num_nodes - F.astype(
+                batch_num_removed_nodes, F.int64
+            )
             # Record old num_edges to check later whether some edges were removed
-            old_num_edges = {c_etype: self._graph.number_of_edges(self.get_etype_id(c_etype))
-                             for c_etype in self.canonical_etypes}
+            old_num_edges = {
+                c_etype: self._graph.number_of_edges(
+                    self.get_etype_id(c_etype)
+                )
+                for c_etype in self.canonical_etypes
+            }

         # node_subgraph
         # If batch_num_edges is to be updated, record the original edge IDs
...
@@ -780,22 +912,36 @@ class DGLGraph(object):
         # If the graph is batched, update batch_num_edges
         if batched:
-            canonical_etypes = [c_etype for c_etype in self.canonical_etypes
-                                if self._graph.number_of_edges(
-                                    self.get_etype_id(c_etype)) != old_num_edges[c_etype]]
+            canonical_etypes = [
+                c_etype
+                for c_etype in self.canonical_etypes
+                if self._graph.number_of_edges(self.get_etype_id(c_etype))
+                != old_num_edges[c_etype]
+            ]
             for c_etype in canonical_etypes:
                 if self._graph.number_of_edges(self.get_etype_id(c_etype)) == 0:
                     self._batch_num_edges[c_etype] = F.zeros(
-                        (self.batch_size,), F.int64, self.device)
+                        (self.batch_size,), F.int64, self.device
+                    )
                     continue
-                one_hot_left_edges = F.zeros((old_num_edges[c_etype],), F.float32, self.device)
+                one_hot_left_edges = F.zeros(
+                    (old_num_edges[c_etype],), F.float32, self.device
+                )
                 eids = self.edges[c_etype].data[EID]
-                one_hot_left_edges = F.scatter_row(one_hot_left_edges, eids,
-                                                   F.full_1d(len(eids), 1., F.float32, self.device))
+                one_hot_left_edges = F.scatter_row(
+                    one_hot_left_edges,
+                    eids,
+                    F.full_1d(len(eids), 1.0, F.float32, self.device),
+                )
                 batch_num_left_edges = segment.segment_reduce(
-                    self._batch_num_edges[c_etype], one_hot_left_edges, reducer='sum')
-                self._batch_num_edges[c_etype] = F.astype(batch_num_left_edges, F.int64)
+                    self._batch_num_edges[c_etype],
+                    one_hot_left_edges,
+                    reducer="sum",
+                )
+                self._batch_num_edges[c_etype] = F.astype(
+                    batch_num_left_edges, F.int64
+                )

         if batched and not store_ids:
             for c_ntype in self.ntypes:
...
@@ -810,7 +956,6 @@ class DGLGraph(object):
         self._batch_num_nodes = None
         self._batch_num_edges = None

-
     #################################################################
     # Metagraph query
     #################################################################
...
@@ -1080,7 +1225,9 @@ class DGLGraph(object):
        nx_graph = self._graph.metagraph.to_networkx()
        nx_metagraph = nx.MultiDiGraph()
        for u_v in nx_graph.edges:
            srctype, etype, dsttype = self.canonical_etypes[
                nx_graph.edges[u_v]["id"]
            ]
            nx_metagraph.add_edge(srctype, dsttype, etype)
        return nx_metagraph
...
...
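For reference, the metagraph built above is a NetworkX MultiDiGraph whose nodes are node types and whose edge keys are relation names. A hedged usage sketch through DGL's public dgl.heterograph API (the graph data is made up):

import dgl
import torch

g = dgl.heterograph({
    ("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1])),
    ("user", "plays", "game"): (torch.tensor([0]), torch.tensor([0])),
})
mg = g.metagraph()
print(list(mg.edges(keys=True)))
# [('user', 'user', 'follows'), ('user', 'game', 'plays')]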
@@ -1133,8 +1280,10 @@ class DGLGraph(object):
"""
if
etype
is
None
:
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
'edge types.'
)
raise
DGLError
(
"Edge type name must be specified if there are more than one "
"edge types."
)
etype
=
self
.
etypes
[
0
]
if
isinstance
(
etype
,
tuple
):
return
etype
...
...
@@ -1143,8 +1292,10 @@ class DGLGraph(object):
        if ret is None:
            raise DGLError('Edge type "{}" does not exist.'.format(etype))
        if len(ret) == 0:
            raise DGLError(
                'Edge type "%s" is ambiguous. Please use canonical edge type '
                "in the form of (srctype, etype, dsttype)" % etype
            )
        return ret

    def get_ntype_id(self, ntype):
...
...
@@ -1164,19 +1315,23 @@ class DGLGraph(object):
"""
if
self
.
is_unibipartite
and
ntype
is
not
None
:
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
if
ntype
.
startswith
(
'
SRC/
'
):
if
ntype
.
startswith
(
"
SRC/
"
):
return
self
.
get_ntype_id_from_src
(
ntype
[
4
:])
elif
ntype
.
startswith
(
'
DST/
'
):
elif
ntype
.
startswith
(
"
DST/
"
):
return
self
.
get_ntype_id_from_dst
(
ntype
[
4
:])
# If there is no prefix, fallback to normal lookup.
# Lookup both SRC and DST
if
ntype
is
None
:
if
self
.
is_unibipartite
or
len
(
self
.
_srctypes_invmap
)
!=
1
:
raise
DGLError
(
'Node type name must be specified if there are more than one '
'node types.'
)
raise
DGLError
(
"Node type name must be specified if there are more than one "
"node types."
)
return
0
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
))
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
)
if
ntid
is
None
:
raise
DGLError
(
'Node type "{}" does not exist.'
.
format
(
ntype
))
return
ntid
...
...
@@ -1198,8 +1353,10 @@ class DGLGraph(object):
"""
if
ntype
is
None
:
if
len
(
self
.
_srctypes_invmap
)
!=
1
:
raise
DGLError
(
'SRC node type name must be specified if there are more than one '
'SRC node types.'
)
raise
DGLError
(
"SRC node type name must be specified if there are more than one "
"SRC node types."
)
return
next
(
iter
(
self
.
_srctypes_invmap
.
values
()))
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
...
...
@@ -1223,8 +1380,10 @@ class DGLGraph(object):
"""
if
ntype
is
None
:
if
len
(
self
.
_dsttypes_invmap
)
!=
1
:
raise
DGLError
(
'DST node type name must be specified if there are more than one '
'DST node types.'
)
raise
DGLError
(
"DST node type name must be specified if there are more than one "
"DST node types."
)
return
next
(
iter
(
self
.
_dsttypes_invmap
.
values
()))
ntid
=
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
...
...
@@ -1248,8 +1407,10 @@ class DGLGraph(object):
"""
if
etype
is
None
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
'edge types.'
)
raise
DGLError
(
"Edge type name must be specified if there are more than one "
"edge types."
)
return
0
etid
=
self
.
_etypes_invmap
.
get
(
self
.
to_canonical_etype
(
etype
),
None
)
if
etid
is
None
:
...
...
@@ -1346,17 +1507,23 @@ class DGLGraph(object):
        tensor([2, 1])
        """
        if ntype is not None and ntype not in self.ntypes:
            raise DGLError(
                "Expect ntype in {}, got {}".format(self.ntypes, ntype)
            )
        if self._batch_num_nodes is None:
            self._batch_num_nodes = {}
            for ty in self.ntypes:
                bnn = F.copy_to(
                    F.tensor([self.number_of_nodes(ty)], F.int64), self.device
                )
                self._batch_num_nodes[ty] = bnn
        if ntype is None:
            if len(self.ntypes) != 1:
                raise DGLError(
                    "Node type name must be specified if there are more than one "
                    "node types."
                )
            ntype = self.ntypes[0]
        return self._batch_num_nodes[ntype]
...
...
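As a usage reminder, batch_num_nodes reports the per-component node counts of a batched graph. A small hedged example with dgl.batch (the two graphs are invented):

import dgl
import torch

g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # 3 nodes
g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])))        # 2 nodes
bg = dgl.batch([g1, g2])
print(bg.batch_num_nodes())  # tensor([3, 2])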
@@ -1440,8 +1607,10 @@ class DGLGraph(object):
"""
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
ntypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary when there are multiple node types.'
)
val
=
{
self
.
ntypes
[
0
]
:
val
}
raise
DGLError
(
"Must provide a dictionary when there are multiple node types."
)
val
=
{
self
.
ntypes
[
0
]:
val
}
self
.
_batch_num_nodes
=
val
def
batch_num_edges
(
self
,
etype
=
None
):
...
...
@@ -1494,12 +1663,16 @@ class DGLGraph(object):
        if self._batch_num_edges is None:
            self._batch_num_edges = {}
            for ty in self.canonical_etypes:
                bne = F.copy_to(
                    F.tensor([self.number_of_edges(ty)], F.int64), self.device
                )
                self._batch_num_edges[ty] = bne
        if etype is None:
            if len(self.etypes) != 1:
                raise DGLError(
                    "Edge type name must be specified if there are more than one "
                    "edge types."
                )
            etype = self.canonical_etypes[0]
        else:
            etype = self.to_canonical_etype(etype)
...
...
@@ -1585,8 +1758,10 @@ class DGLGraph(object):
"""
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary when there are multiple edge types.'
)
val
=
{
self
.
canonical_etypes
[
0
]
:
val
}
raise
DGLError
(
"Must provide a dictionary when there are multiple edge types."
)
val
=
{
self
.
canonical_etypes
[
0
]:
val
}
self
.
_batch_num_edges
=
val
#################################################################
...
...
@@ -2130,10 +2305,14 @@ class DGLGraph(object):
    def _find_etypes(self, key):
        etypes = [
            i
            for i, (srctype, etype, dsttype) in enumerate(self._canonical_etypes)
            if (key[0] == SLICE_FULL or key[0] == srctype)
            and (key[1] == SLICE_FULL or key[1] == etype)
            and (key[2] == SLICE_FULL or key[2] == dsttype)
        ]
        return etypes

    def __getitem__(self, key):
...
...
@@ -2215,9 +2394,11 @@ class DGLGraph(object):
        >>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
        tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
        """
        err_msg = (
            "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
            + "to get view of one relation type. Use : to slice multiple types (e.g. "
            + "G['srctype', :, 'dsttype'])."
        )
        orig_key = key
        if not isinstance(key, tuple):
...
...
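The slice syntax described in the error message above can be exercised directly. A hedged example of relation slicing (graph definition invented for illustration):

import dgl
import torch

g = dgl.heterograph({
    ("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1])),
    ("user", "plays", "game"): (torch.tensor([0]), torch.tensor([0])),
})
sub = g["plays"]             # view of one relation type
sub2 = g["user", :, "game"]  # slice: every etype from user to game
print(sub.num_edges(), sub2.num_edges())  # 1 1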
@@ -2229,7 +2410,11 @@ class DGLGraph(object):
        etypes = self._find_etypes(key)

        if len(etypes) == 0:
            raise DGLError(
                'Invalid key "{}". Must be one of the edge types.'.format(orig_key)
            )

        if len(etypes) == 1:
            # no ambiguity: return the unitgraph itself
...
...
@@ -2248,7 +2433,9 @@ class DGLGraph(object):
            new_etypes = [etype]
            new_eframes = [self._edge_frames[etid]]

            return self.__class__(
                new_g, new_ntypes, new_etypes, new_nframes, new_eframes
            )
        else:
            flat = self._graph.flatten_relations(etypes)
            new_g = flat.graph
...
...
@@ -2262,7 +2449,8 @@ class DGLGraph(object):
                new_ntypes.append(combine_names(self.ntypes, dtids))
                new_nframes = [
                    combine_frames(self._node_frames, stids),
                    combine_frames(self._node_frames, dtids),
                ]
            else:
                assert np.array_equal(stids, dtids)
                new_nframes = [combine_frames(self._node_frames, stids)]
...
...
@@ -2270,16 +2458,28 @@ class DGLGraph(object):
            new_eframes = [combine_frames(self._edge_frames, etids)]

            # create new heterograph
            new_hg = self.__class__(
                new_g, new_ntypes, new_etypes, new_nframes, new_eframes
            )

            src = new_ntypes[0]
            dst = new_ntypes[1] if new_g.number_of_ntypes() == 2 else src
            # put the parent node/edge type and IDs
            new_hg.nodes[src].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
                flat.induced_srctype
            )
            new_hg.nodes[src].data[NID] = F.zerocopy_from_dgl_ndarray(
                flat.induced_srcid
            )
            new_hg.nodes[dst].data[NTYPE] = F.zerocopy_from_dgl_ndarray(
                flat.induced_dsttype
            )
            new_hg.nodes[dst].data[NID] = F.zerocopy_from_dgl_ndarray(
                flat.induced_dstid
            )
            new_hg.edata[ETYPE] = F.zerocopy_from_dgl_ndarray(flat.induced_etype)
            new_hg.edata[EID] = F.zerocopy_from_dgl_ndarray(flat.induced_eid)

            return new_hg
...
...
@@ -2331,7 +2531,12 @@ class DGLGraph(object):
        12
        """
        if ntype is None:
            return sum(
                [
                    self._graph.number_of_nodes(ntid)
                    for ntid in range(len(self.ntypes))
                ]
            )
        else:
            return self._graph.number_of_nodes(self.get_ntype_id(ntype))
...
...
@@ -2396,10 +2601,16 @@ class DGLGraph(object):
        7
        """
        if ntype is None:
            return sum(
                [
                    self._graph.number_of_nodes(self.get_ntype_id_from_src(nty))
                    for nty in self.srctypes
                ]
            )
        else:
            return self._graph.number_of_nodes(
                self.get_ntype_id_from_src(ntype)
            )

    def number_of_dst_nodes(self, ntype=None):
        """Alias of :func:`num_dst_nodes`"""
...
...
@@ -2462,10 +2673,16 @@ class DGLGraph(object):
        12
        """
        if ntype is None:
            return sum(
                [
                    self._graph.number_of_nodes(self.get_ntype_id_from_dst(nty))
                    for nty in self.dsttypes
                ]
            )
        else:
            return self._graph.number_of_nodes(
                self.get_ntype_id_from_dst(ntype)
            )

    def number_of_edges(self, etype=None):
        """Alias of :func:`num_edges`"""
...
...
@@ -2522,8 +2739,12 @@ class DGLGraph(object):
        3
        """
        if etype is None:
            return sum(
                [
                    self._graph.number_of_edges(etid)
                    for etid in range(len(self.canonical_etypes))
                ]
            )
        else:
            return self._graph.number_of_edges(self.get_etype_id(etype))
...
...
@@ -2708,10 +2929,11 @@ class DGLGraph(object):
        tensor([False,  True,  True])
        """
        vid_tensor = utils.prepare_tensor(self, vid, "vid")
        if len(vid_tensor) > 0 and F.as_scalar(F.min(vid_tensor, 0)) < 0 < len(
            vid_tensor
        ):
            raise DGLError("All IDs must be non-negative integers.")
        ret = self._graph.has_nodes(self.get_ntype_id(ntype), vid_tensor)
        if isinstance(vid, numbers.Integral):
            return bool(F.as_scalar(ret))
        else:
...
...
@@ -2793,15 +3015,19 @@ class DGLGraph(object):
        tensor([True, True])
        """
        srctype, _, dsttype = self.to_canonical_etype(etype)
        u_tensor = utils.prepare_tensor(self, u, "u")
        if F.as_scalar(
            F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
        ) != len(u_tensor):
            raise DGLError("u contains invalid node IDs")
        v_tensor = utils.prepare_tensor(self, v, "v")
        if F.as_scalar(
            F.sum(self.has_nodes(v_tensor, ntype=dsttype), dim=0)
        ) != len(v_tensor):
            raise DGLError("v contains invalid node IDs")
        ret = self._graph.has_edges_between(
            self.get_etype_id(etype), u_tensor, v_tensor
        )
        if isinstance(u, numbers.Integral) and isinstance(v, numbers.Integral):
            return bool(F.as_scalar(ret))
        else:
...
...
@@ -2863,7 +3089,7 @@ class DGLGraph(object):
        successors
        """
        if not self.has_nodes(v, self.to_canonical_etype(etype)[-1]):
            raise DGLError("Non-existing node ID {}".format(v))
        return self._graph.predecessors(self.get_etype_id(etype), v)

    def successors(self, v, etype=None):
...
...
@@ -2921,7 +3147,7 @@ class DGLGraph(object):
        predecessors
        """
        if not self.has_nodes(v, self.to_canonical_etype(etype)[0]):
            raise DGLError("Non-existing node ID {}".format(v))
        return self._graph.successors(self.get_etype_id(etype), v)

    def edge_ids(self, u, v, return_uv=False, etype=None):
...
...
@@ -3018,14 +3244,20 @@ class DGLGraph(object):
        ...              etype=('user', 'follows', 'game'))
        tensor([1, 2])
        """
        is_int = isinstance(u, numbers.Integral) and isinstance(
            v, numbers.Integral
        )
        srctype, _, dsttype = self.to_canonical_etype(etype)
        u = utils.prepare_tensor(self, u, "u")
        if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u):
            raise DGLError("u contains invalid node IDs")
        v = utils.prepare_tensor(self, v, "v")
        if F.as_scalar(F.sum(self.has_nodes(v, ntype=dsttype), dim=0)) != len(v):
            raise DGLError("v contains invalid node IDs")
        if return_uv:
            return self._graph.edge_ids_all(self.get_etype_id(etype), u, v)
...
...
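edge_ids maps endpoint pairs to edge IDs after validating both endpoint lists. A hedged example on a homogeneous graph (edges invented):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
print(g.edge_ids(torch.tensor([0, 1]), torch.tensor([2, 2])))  # tensor([1, 2])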
@@ -3035,9 +3267,13 @@ class DGLGraph(object):
        if F.as_scalar(F.sum(is_neg_one, 0)):
            # Raise error since some (u, v) pair is not a valid edge.
            idx = F.nonzero_1d(is_neg_one)
            raise DGLError(
                "Error: (%d, %d) does not form a valid edge."
                % (
                    F.as_scalar(F.gather_row(u, idx)),
                    F.as_scalar(F.gather_row(v, idx)),
                )
            )
        return F.as_scalar(eid) if is_int else eid

    def find_edges(self, eid, etype=None):
...
...
@@ -3096,14 +3332,14 @@ class DGLGraph(object):
        >>> hg.find_edges(torch.tensor([1, 0]), 'plays')
        (tensor([4, 3]), tensor([6, 5]))
        """
        eid = utils.prepare_tensor(self, eid, "eid")
        if len(eid) > 0:
            min_eid = F.as_scalar(F.min(eid, 0))
            if min_eid < 0:
                raise DGLError("Invalid edge ID {:d}".format(min_eid))
            max_eid = F.as_scalar(F.max(eid, 0))
            if max_eid >= self.num_edges(etype):
                raise DGLError("Invalid edge ID {:d}".format(max_eid))

        if len(eid) == 0:
            empty = F.copy_to(F.tensor([], self.idtype), self.device)
...
...
@@ -3111,7 +3347,7 @@ class DGLGraph(object):
        src, dst, _ = self._graph.find_edges(self.get_etype_id(etype), eid)
        return src, dst

    def in_edges(self, v, form="uv", etype=None):
        """Return the incoming edges of the given nodes.

        Parameters
...
...
@@ -3184,18 +3420,20 @@ class DGLGraph(object):
        edges
        out_edges
        """
        v = utils.prepare_tensor(self, v, "v")
        src, dst, eid = self._graph.in_edges(self.get_etype_id(etype), v)
        if form == "all":
            return src, dst, eid
        elif form == "uv":
            return src, dst
        elif form == "eid":
            return eid
        else:
            raise DGLError(
                'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
            )

    def out_edges(self, u, form="uv", etype=None):
        """Return the outgoing edges of the given nodes.

        Parameters
...
...
@@ -3268,21 +3506,25 @@ class DGLGraph(object):
        edges
        in_edges
        """
        u = utils.prepare_tensor(self, u, "u")
        srctype, _, _ = self.to_canonical_etype(etype)
        if F.as_scalar(F.sum(self.has_nodes(u, ntype=srctype), dim=0)) != len(u):
            raise DGLError("u contains invalid node IDs")
        src, dst, eid = self._graph.out_edges(self.get_etype_id(etype), u)
        if form == "all":
            return src, dst, eid
        elif form == "uv":
            return src, dst
        elif form == "eid":
            return eid
        else:
            raise DGLError(
                'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
            )

    def all_edges(self, form="uv", order="eid", etype=None):
        """Return all edges with the specified edge type.

        Parameters
...
...
@@ -3353,14 +3595,16 @@ class DGLGraph(object):
        out_edges
        """
        src, dst, eid = self._graph.edges(self.get_etype_id(etype), order)
        if form == "all":
            return src, dst, eid
        elif form == "uv":
            return src, dst
        elif form == "eid":
            return eid
        else:
            raise DGLError(
                'Invalid form: {}. Must be "all", "uv" or "eid".'.format(form)
            )

    def in_degrees(self, v=ALL, etype=None):
        """Return the in-degree(s) of the given nodes.
...
...
@@ -3431,7 +3675,7 @@ class DGLGraph(object):
        etid = self.get_etype_id(etype)
        if is_all(v):
            v = self.dstnodes(dsttype)
        v_tensor = utils.prepare_tensor(self, v, "v")
        deg = self._graph.in_degrees(etid, v_tensor)
        if isinstance(v, numbers.Integral):
            return F.as_scalar(deg)
...
...
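A quick sanity check on the degree queries above; note that an integral input returns a Python scalar while the default v=ALL returns a tensor over all destination nodes. A hedged example (edges invented):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
print(g.in_degrees())   # tensor([0, 1, 2])
print(g.in_degrees(2))  # 2, a plain scalar for integral input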
@@ -3507,16 +3751,20 @@ class DGLGraph(object):
        etid = self.get_etype_id(etype)
        if is_all(u):
            u = self.srcnodes(srctype)
        u_tensor = utils.prepare_tensor(self, u, "u")
        if F.as_scalar(
            F.sum(self.has_nodes(u_tensor, ntype=srctype), dim=0)
        ) != len(u_tensor):
            raise DGLError("u contains invalid node IDs")
        deg = self._graph.out_degrees(etid, utils.prepare_tensor(self, u, "u"))
        if isinstance(u, numbers.Integral):
            return F.as_scalar(deg)
        else:
            return deg

    def adjacency_matrix(
        self, transpose=False, ctx=F.cpu(), scipy_fmt=None, etype=None
    ):
        """Alias of :meth:`adj`"""
        return self.adj(transpose, ctx, scipy_fmt, etype)
...
...
@@ -3586,7 +3834,9 @@ class DGLGraph(object):
        if scipy_fmt is None:
            return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
        else:
            return self._graph.adjacency_matrix_scipy(
                etid, transpose, scipy_fmt, False
            )

    def adj_sparse(self, fmt, etype=None):
        """Return the adjacency matrix of edges of the given edge type as tensors of
...
...
@@ -3629,9 +3879,9 @@ class DGLGraph(object):
        (tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))
        """
        etid = self.get_etype_id(etype)
        if fmt == "csc":
            # The first two elements are number of rows and columns
            return self._graph.adjacency_matrix_tensors(etid, True, "csr")[2:]
        else:
            return self._graph.adjacency_matrix_tensors(etid, False, fmt)[2:]
...
...
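The "csc" branch relies on the identity that the CSC representation of a matrix A has the same index arrays as the CSR representation of its transpose, which is why it requests the transposed CSR. A hedged standalone illustration of that identity in SciPy (matrix values invented):

import numpy as np
import scipy.sparse as sp

A = sp.coo_matrix((np.ones(3), ([0, 1, 2], [1, 2, 0])), shape=(3, 3))
csc = A.tocsc()
csr_of_transpose = A.T.tocsr()
# Identical index structure: CSC(A) == CSR(A^T).
assert (csc.indptr == csr_of_transpose.indptr).all()
assert (csc.indices == csr_of_transpose.indices).all()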
@@ -4024,26 +4274,36 @@ class DGLGraph(object):
        if is_all(u):
            num_nodes = self._graph.number_of_nodes(ntid)
        else:
            u = utils.prepare_tensor(self, u, "u")
            num_nodes = len(u)
        for key, val in data.items():
            nfeats = F.shape(val)[0]
            if nfeats != num_nodes:
                raise DGLError(
                    "Expect number of features to match number of nodes (len(u))."
                    " Got %d and %d instead." % (nfeats, num_nodes)
                )
            if F.context(val) != self.device:
                raise DGLError(
                    'Cannot assign node feature "{}" on device {} to a graph on'
                    " device {}. Call DGLGraph.to() to copy the graph to the"
                    " same device.".format(key, F.context(val), self.device)
                )
            # To prevent users from doing things like:
            #
            #     g.pin_memory_()
            #     g.ndata['x'] = torch.randn(...)
            #     sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
            #     sg.ndata['x']  # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
            if (
                self.is_pinned()
                and F.context(val) == "cpu"
                and not F.is_pinned(val)
            ):
                raise DGLError(
                    "Pinned graph requires the node data to be pinned as well. "
                    "Please pin the node data before assignment."
                )

        if is_all(u):
            self._node_frames[ntid].update(data)
...
...
@@ -4070,7 +4330,7 @@ class DGLGraph(object):
        if is_all(u):
            return self._node_frames[ntid]
        else:
            u = utils.prepare_tensor(self, u, "u")
            return self._node_frames[ntid].subframe(u)

    def _pop_n_repr(self, ntid, key):
...
...
@@ -4116,12 +4376,14 @@ class DGLGraph(object):
"""
# parse argument
if
not
is_all
(
edges
):
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
'
edges
'
)
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
"
edges
"
)
# sanity check
if
not
utils
.
is_dict_like
(
data
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
data
))
raise
DGLError
(
"Expect dictionary type for feature data."
' Got "%s" instead.'
%
type
(
data
)
)
if
is_all
(
edges
):
num_edges
=
self
.
_graph
.
number_of_edges
(
etid
)
...
...
@@ -4130,21 +4392,31 @@ class DGLGraph(object):
        for key, val in data.items():
            nfeats = F.shape(val)[0]
            if nfeats != num_edges:
                raise DGLError(
                    "Expect number of features to match number of edges."
                    " Got %d and %d instead." % (nfeats, num_edges)
                )
            if F.context(val) != self.device:
                raise DGLError(
                    'Cannot assign edge feature "{}" on device {} to a graph on'
                    " device {}. Call DGLGraph.to() to copy the graph to the"
                    " same device.".format(key, F.context(val), self.device)
                )
            # To prevent users from doing things like:
            #
            #     g.pin_memory_()
            #     g.edata['x'] = torch.randn(...)
            #     sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
            #     sg.edata['x']  # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
            if (
                self.is_pinned()
                and F.context(val) == "cpu"
                and not F.is_pinned(val)
            ):
                raise DGLError(
                    "Pinned graph requires the edge data to be pinned as well. "
                    "Please pin the edge data before assignment."
                )

        # set
        if is_all(edges):
...
...
@@ -4172,7 +4444,7 @@ class DGLGraph(object):
        if is_all(edges):
            return self._edge_frames[etid]
        else:
            eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
            return self._edge_frames[etid].subframe(eid)

    def _pop_e_repr(self, etid, key):
...
...
@@ -4256,7 +4528,7 @@ class DGLGraph(object):
        if is_all(v):
            v_id = self.nodes(ntype)
        else:
            v_id = utils.prepare_tensor(self, v, "v")
        ndata = core.invoke_node_udf(self, v_id, ntype, func, orig_nid=v_id)
        self._set_n_repr(ntid, v, ndata)
...
...
@@ -4346,16 +4618,18 @@ class DGLGraph(object):
            etid = self.get_etype_id(etype)
            etype = self.canonical_etypes[etid]
            g = self if etype is None else self[etype]
        else:  # heterogeneous graph with number of relation types > 1
            if not core.is_builtin(func):
                raise DGLError(
                    "User defined functions are not yet "
                    "supported in apply_edges for heterogeneous graphs. "
                    "Please use (apply_edges(func), etype = rel) instead."
                )
            g = self
        if is_all(edges):
            eid = ALL
        else:
            eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
        if core.is_builtin(func):
            if not is_all(eid):
                g = g.edge_subgraph(eid, relabel_nodes=False)
...
...
@@ -4375,12 +4649,9 @@ class DGLGraph(object):
                edata_tensor[key] = out_tensor_tuples[etid]
            self._set_e_repr(etid, eid, edata_tensor)

    def send_and_recv(
        self, edges, message_func, reduce_func, apply_node_func=None, etype=None
    ):
        """Send messages along the specified edges and reduce them on
        the destination nodes to update their features.
...
...
@@ -4493,7 +4764,7 @@ class DGLGraph(object):
        _, dtid = self._graph.metagraph.find_edge(etid)
        etype = self.canonical_etypes[etid]
        # edge IDs
        eid = utils.parse_edges_arg_to_eid(self, edges, etid, "edges")
        if len(eid) == 0:
            # no computation
            return
...
...
@@ -4502,15 +4773,13 @@ class DGLGraph(object):
        g = self if etype is None else self[etype]
        compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid)
        ndata = core.message_passing(
            compute_graph, message_func, reduce_func, apply_node_func
        )
        self._set_n_repr(dtid, dstnodes, ndata)

    def pull(self, v, message_func, reduce_func, apply_node_func=None, etype=None):
        """Pull messages from the specified node(s)' predecessors along the
        specified edge type, aggregate them to update the node features.
...
...
@@ -4588,7 +4857,7 @@ class DGLGraph(object):
               [1.],
               [1.]])
        """
        v = utils.prepare_tensor(self, v, "v")
        if len(v) == 0:
            # no computation
            return
...
...
@@ -4597,18 +4866,18 @@ class DGLGraph(object):
        etype = self.canonical_etypes[etid]
        g = self if etype is None else self[etype]
        # call message passing on subgraph
        src, dst, eid = g.in_edges(v, form="all")
        compute_graph, _, dstnodes, _ = _create_compute_graph(
            g, src, dst, eid, v
        )
        ndata = core.message_passing(
            compute_graph, message_func, reduce_func, apply_node_func
        )
        self._set_n_repr(dtid, dstnodes, ndata)

    def push(self, u, message_func, reduce_func, apply_node_func=None, etype=None):
        """Send message from the specified node(s) to their successors
        along the specified edge type and update their node features.
...
...
@@ -4679,14 +4948,14 @@ class DGLGraph(object):
               [0.],
               [0.]])
        """
        edges = self.out_edges(u, form="eid", etype=etype)
        self.send_and_recv(
            edges, message_func, reduce_func, apply_node_func, etype=etype
        )

    def update_all(
        self, message_func, reduce_func, apply_node_func=None, etype=None
    ):
        """Send messages along all the edges of the specified type
        and update all the nodes of the corresponding destination type.
...
...
@@ -4778,23 +5047,37 @@ class DGLGraph(object):
            etype = self.canonical_etypes[etid]
            _, dtid = self._graph.metagraph.find_edge(etid)
            g = self if etype is None else self[etype]
            ndata = core.message_passing(
                g, message_func, reduce_func, apply_node_func
            )
            if (
                core.is_builtin(reduce_func)
                and reduce_func.name in ["min", "max"]
                and ndata
            ):
                # Replace infinity with zero for isolated nodes
                key = list(ndata.keys())[0]
                ndata[key] = F.replace_inf_with_zero(ndata[key])
            self._set_n_repr(dtid, ALL, ndata)
        else:  # heterogeneous graph with number of relation types > 1
            if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):
                raise DGLError(
                    "User defined functions are not yet "
                    "supported in update_all for heterogeneous graphs. "
                    "Please use multi_update_all instead."
                )
            if reduce_func.name in ["mean"]:
                raise NotImplementedError(
                    "Cannot set both intra-type and inter-type reduce "
                    "operators as 'mean' using update_all. Please use "
                    "multi_update_all instead."
                )
            g = self
            all_out = core.message_passing(
                g, message_func, reduce_func, apply_node_func
            )
            key = list(all_out.keys())[0]
            out_tensor_tuples = all_out[key]
...
...
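For context, update_all is the one-shot message-passing entry point; with builtin functions the message and reduce stages are fused, and min/max reduction replaces the infinities left on isolated nodes with zeros, as the branch above shows. A hedged example on a homogeneous graph (feature values invented):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.ndata["h"] = torch.ones(3, 1)
# Copy each source's 'h' onto its out-edges, then max-reduce per destination.
g.update_all(fn.copy_u("h", "m"), fn.max("m", "h_new"))
print(g.ndata["h_new"])  # node 0 has no in-edges; it gets 0 after inf-replacement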
@@ -4802,7 +5085,10 @@ class DGLGraph(object):
            for _, _, dsttype in g.canonical_etypes:
                dtid = g.get_ntype_id(dsttype)
                dst_tensor[key] = out_tensor_tuples[dtid]
                if core.is_builtin(reduce_func) and reduce_func.name in [
                    "min",
                    "max",
                ]:
                    dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key])
                self._node_frames[dtid].update(dst_tensor)
...
...
@@ -4902,36 +5188,44 @@ class DGLGraph(object):
            _, dtid = self._graph.metagraph.find_edge(etid)
            args = pad_tuple(args, 3)
            if args is None:
                raise DGLError(
                    'Invalid arguments for edge type "{}". Should be '
                    "(msg_func, reduce_func, [apply_node_func])".format(etype)
                )
            mfunc, rfunc, afunc = args
            g = self if etype is None else self[etype]
            all_out[dtid].append(core.message_passing(g, mfunc, rfunc, afunc))
            merge_order[dtid].append(etid)  # use edge type id as merge order hint
        for dtid, frames in all_out.items():
            # merge by cross_reducer
            out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])
            # Replace infinity with zero for isolated nodes when reducer is min/max
            if core.is_builtin(rfunc) and rfunc.name in ["min", "max"]:
                key = list(out.keys())[0]
                out[key] = (
                    F.replace_inf_with_zero(out[key])
                    if out[key] is not None
                    else None
                )
            self._node_frames[dtid].update(out)
        # apply
        if apply_node_func is not None:
            self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])

    #################################################################
    # Message propagation
    #################################################################

    def prop_nodes(
        self,
        nodes_generator,
        message_func,
        reduce_func,
        apply_node_func=None,
        etype=None,
    ):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`pull()` on nodes.
...
...
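The loop reconstructed above is the heart of multi_update_all: each relation runs its own message passing, and the per-destination results are merged by cross_reducer. A hedged usage sketch (relation names and feature values invented):

import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ("user", "follows", "user"): (torch.tensor([0]), torch.tensor([1])),
    ("game", "attracts", "user"): (torch.tensor([0]), torch.tensor([1])),
})
g.nodes["user"].data["h"] = torch.ones(2, 1)
g.nodes["game"].data["h"] = torch.full((1, 1), 2.0)
g.multi_update_all(
    {
        "follows": (fn.copy_u("h", "m"), fn.sum("m", "h_new")),
        "attracts": (fn.copy_u("h", "m"), fn.sum("m", "h_new")),
    },
    cross_reducer="sum",  # merge results across relation types
)
print(g.nodes["user"].data["h_new"])  # user 1 receives 1 + 2 = 3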
@@ -4987,14 +5281,22 @@ class DGLGraph(object):
        prop_edges
        """
        for node_frontier in nodes_generator:
            self.pull(
                node_frontier,
                message_func,
                reduce_func,
                apply_node_func,
                etype=etype,
            )

    def prop_edges(
        self,
        edges_generator,
        message_func,
        reduce_func,
        apply_node_func=None,
        etype=None,
    ):
        """Propagate messages using graph traversal by sequentially triggering
        :func:`send_and_recv()` on edges.
...
...
@@ -5051,8 +5353,13 @@ class DGLGraph(object):
        prop_nodes
        """
        for edge_frontier in edges_generator:
            self.send_and_recv(
                edge_frontier,
                message_func,
                reduce_func,
                apply_node_func,
                etype=etype,
            )

    #################################################################
    # Misc
...
...
@@ -5127,14 +5434,16 @@ class DGLGraph(object):
"""
if
is_all
(
nodes
):
nodes
=
self
.
nodes
(
ntype
)
v
=
utils
.
prepare_tensor
(
self
,
nodes
,
'
nodes
'
)
v
=
utils
.
prepare_tensor
(
self
,
nodes
,
"
nodes
"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
ntype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
'
v contains invalid node IDs
'
)
raise
DGLError
(
"
v contains invalid node IDs
"
)
with
self
.
local_scope
():
self
.
apply_nodes
(
lambda
nbatch
:
{
'_mask'
:
predicate
(
nbatch
)},
nodes
,
ntype
)
self
.
apply_nodes
(
lambda
nbatch
:
{
"_mask"
:
predicate
(
nbatch
)},
nodes
,
ntype
)
ntype
=
self
.
ntypes
[
0
]
if
ntype
is
None
else
ntype
mask
=
self
.
nodes
[
ntype
].
data
[
'
_mask
'
]
mask
=
self
.
nodes
[
ntype
].
data
[
"
_mask
"
]
if
is_all
(
nodes
):
return
F
.
nonzero_1d
(
mask
)
else
:
...
...
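This mask-based filtering pattern (evaluate a predicate under local_scope, then select the nonzero indices) can be used directly through filter_nodes. A hedged example (feature and predicate invented):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
g.ndata["x"] = torch.tensor([0.0, 1.0, 2.0])
# Select nodes whose feature exceeds 0.5.
nids = g.filter_nodes(lambda nodes: nodes.data["x"] > 0.5)
print(nids)  # tensor([1, 2])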
@@ -5221,34 +5530,40 @@ class DGLGraph(object):
        elif isinstance(edges, tuple):
            u, v = edges
            srctype, _, dsttype = self.to_canonical_etype(etype)
            u = utils.prepare_tensor(self, u, "u")
            if F.as_scalar(
                F.sum(self.has_nodes(u, ntype=srctype), dim=0)
            ) != len(u):
                raise DGLError("edges[0] contains invalid node IDs")
            v = utils.prepare_tensor(self, v, "v")
            if F.as_scalar(
                F.sum(self.has_nodes(v, ntype=dsttype), dim=0)
            ) != len(v):
                raise DGLError("edges[1] contains invalid node IDs")
        elif isinstance(edges, Iterable) or F.is_tensor(edges):
            edges = utils.prepare_tensor(self, edges, "edges")
            min_eid = F.as_scalar(F.min(edges, 0))
            if len(edges) > 0 > min_eid:
                raise DGLError("Invalid edge ID {:d}".format(min_eid))
            max_eid = F.as_scalar(F.max(edges, 0))
            if len(edges) > 0 and max_eid >= self.num_edges(etype):
                raise DGLError("Invalid edge ID {:d}".format(max_eid))
        else:
            raise ValueError("Unsupported type of edges:", type(edges))

        with self.local_scope():
            self.apply_edges(lambda ebatch: {"_mask": predicate(ebatch)}, edges, etype)
            etype = self.canonical_etypes[0] if etype is None else etype
            mask = self.edges[etype].data["_mask"]
            if is_all(edges):
                return F.nonzero_1d(mask)
            else:
                if isinstance(edges, tuple):
                    e = self.edge_ids(edges[0], edges[1], etype=etype)
                else:
                    e = utils.prepare_tensor(self, edges, "edges")
                return F.boolean_mask(e, F.gather_row(mask, e))

    @property
...
...
@@ -5347,12 +5662,16 @@ class DGLGraph(object):
        # 2. Copy misc info
        if self._batch_num_nodes is not None:
            new_bnn = {
                k: F.copy_to(num, device, **kwargs)
                for k, num in self._batch_num_nodes.items()
            }
            ret._batch_num_nodes = new_bnn
        if self._batch_num_edges is not None:
            new_bne = {
                k: F.copy_to(num, device, **kwargs)
                for k, num in self._batch_num_edges.items()
            }
            ret._batch_num_edges = new_bne

        return ret
...
...
@@ -5432,8 +5751,10 @@ class DGLGraph(object):
        tensor([0, 1, 1])
        """
        if not self._graph.is_pinned():
            if F.device_type(self.device) != "cpu":
                raise DGLError("The graph structure must be on CPU to be pinned.")
            self._graph.pin_memory_()
        for frame in itertools.chain(self._node_frames, self._edge_frames):
            for col in frame._columns.values():
...
...
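Pinning keeps the CPU graph in page-locked memory so GPU samplers can read it via zero-copy access. A hedged sketch of the intended call sequence (this requires a CUDA-enabled build of DGL and PyTorch; the tensors are invented):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # CPU graph
g.ndata["x"] = torch.randn(3, 4).pin_memory()  # pin features first
g.pin_memory_()  # pin the structure in place; g stays a CPU graph
assert g.is_pinned()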
@@ -5484,9 +5805,9 @@ class DGLGraph(object):
        DGLGraph
            self.
        """
        if F.get_preferred_backend() != "pytorch":
            raise DGLError("record_stream only support the PyTorch backend.")
        if F.device_type(self.device) != "cuda":
            raise DGLError("The graph must be on GPU to be recorded.")
        self._graph.record_stream(stream)
        for frame in itertools.chain(self._node_frames, self._edge_frames):
...
...
@@ -5510,15 +5831,24 @@ class DGLGraph(object):
        # Clone the graph structure
        meta_edges = []
        for s_ntype, _, d_ntype in self.canonical_etypes:
            meta_edges.append(
                (self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))
            )

        metagraph = graph_index.from_edge_list(meta_edges, True)
        # rebuild graph idx
        num_nodes_per_type = [
            self.number_of_nodes(c_ntype) for c_ntype in self.ntypes
        ]
        relation_graphs = [
            self._graph.get_relation_graph(self.get_etype_id(c_etype))
            for c_etype in self.canonical_etypes
        ]
        ret._graph = heterograph_index.create_heterograph_from_relations(
            metagraph,
            relation_graphs,
            utils.toindex(num_nodes_per_type, "int64"),
        )

        # Clone the frames
        ret._node_frames = [fr.clone() for fr in self._node_frames]
...
...
@@ -5819,7 +6149,7 @@ class DGLGraph(object):
        return ret

    # TODO: Formats should not be specified, just saving all the materialized formats
    def shared_memory(self, name, formats=("coo", "csr", "csc")):
        """Return a copy of this graph in shared memory, without node data or edge data.

        It moves the graph index to shared memory and returns a DGLGraph object which
...
...
@@ -5843,11 +6173,16 @@ class DGLGraph(object):
        if isinstance(formats, str):
            formats = [formats]
        for fmt in formats:
            assert fmt in (
                "coo",
                "csr",
                "csc",
            ), "{} is not coo, csr or csc".format(fmt)
        gidx = self._graph.shared_memory(name, self.ntypes, self.etypes, formats)
        return DGLGraph(gidx, self.ntypes, self.etypes)

    def long(self):
        """Cast the graph to one with idtype int64
...
...
@@ -5948,10 +6283,12 @@ class DGLGraph(object):
"""
return
self
.
astype
(
F
.
int32
)
############################################################
# Internal APIs
############################################################
def
make_canonical_etypes
(
etypes
,
ntypes
,
metagraph
):
"""Internal function to convert etype name to (srctype, etype, dsttype)
...
...
@@ -5970,19 +6307,29 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
"""
# sanity check
if
len
(
etypes
)
!=
metagraph
.
number_of_edges
():
raise
DGLError
(
'Length of edge type list must match the number of '
'edges in the metagraph. {} vs {}'
.
format
(
len
(
etypes
),
metagraph
.
number_of_edges
()))
raise
DGLError
(
"Length of edge type list must match the number of "
"edges in the metagraph. {} vs {}"
.
format
(
len
(
etypes
),
metagraph
.
number_of_edges
()
)
)
if
len
(
ntypes
)
!=
metagraph
.
number_of_nodes
():
raise
DGLError
(
'Length of nodes type list must match the number of '
'nodes in the metagraph. {} vs {}'
.
format
(
len
(
ntypes
),
metagraph
.
number_of_nodes
()))
if
(
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
):
raise
DGLError
(
"Length of nodes type list must match the number of "
"nodes in the metagraph. {} vs {}"
.
format
(
len
(
ntypes
),
metagraph
.
number_of_nodes
()
)
)
if
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
:
return
[(
ntypes
[
0
],
etypes
[
0
],
ntypes
[
0
])]
src
,
dst
,
eid
=
metagraph
.
edges
(
order
=
"eid"
)
rst
=
[(
ntypes
[
sid
],
etypes
[
eid
],
ntypes
[
did
])
for
sid
,
did
,
eid
in
zip
(
src
,
dst
,
eid
)]
rst
=
[
(
ntypes
[
sid
],
etypes
[
eid
],
ntypes
[
did
])
for
sid
,
did
,
eid
in
zip
(
src
,
dst
,
eid
)
]
return
rst
def
find_src_dst_ntypes
(
ntypes
,
metagraph
):
"""Internal function to split ntypes into SRC and DST categories.
...
...
@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph):
        return None
    else:
        src, dst = ret
        srctypes = {ntypes[tid]: tid for tid in src}
        dsttypes = {ntypes[tid]: tid for tid in dst}
        return srctypes, dsttypes


def pad_tuple(tup, length, pad_val=None):
    """Pad the given tuple to the given length.
...
...
@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None):
    Return None if pad fails.
    """
    if not isinstance(tup, tuple):
        tup = (tup,)
    if len(tup) > length:
        return None
    elif len(tup) == length:
...
...
@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None):
    else:
        return tup + (pad_val,) * (length - len(tup))


def reduce_dict_data(frames, reducer, order=None):
    """Merge tensor dictionaries into one. Resolve conflict fields using reducer.
...
...
@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None):
    dict[str, Tensor]
        Merged frame
    """
    if len(frames) == 1 and reducer != "stack":
        # Directly return the only one input. Stack reducer requires
        # modifying tensor shape.
        return frames[0]
    if callable(reducer):
        merger = reducer
    elif reducer == "stack":
        # Stack order does not matter. However, it must be consistent!
        if order:
            assert len(order) == len(frames)
            sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
            frames = list(zip(*sorted_with_key))[0]

        def merger(flist):
            return F.stack(flist, 1)

    else:
        redfn = getattr(F, reducer, None)
        if redfn is None:
            raise DGLError(
                "Invalid cross type reducer. Must be one of "
                '"sum", "max", "min", "mean" or "stack".'
            )

        def merger(flist):
            return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]

    keys = set()
    for frm in frames:
        keys.update(frm.keys())
...
...
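To make the reducer semantics concrete, here is a hedged standalone re-creation of the merge rule for the "sum" and "stack" cases in plain PyTorch (variable names invented):

import torch

frames = [{"h": torch.ones(2, 1)}, {"h": torch.full((2, 1), 2.0)}]

# "sum" reducer: stack conflicting fields on a new axis 0, then reduce it.
summed = torch.stack([f["h"] for f in frames], 0).sum(0)  # shape (2, 1)

# "stack" reducer: keep one entry per frame along axis 1.
stacked = torch.stack([f["h"] for f in frames], 1)        # shape (2, 2, 1)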
@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None):
        for key, scheme in list(schemes.items()):
            if key in frame.schemes:
                if frame.schemes[key] != scheme:
                    raise DGLError(
                        "Cannot concatenate column %s with shape %s and shape %s"
                        % (key, frame.schemes[key], scheme)
                    )
            else:
                del schemes[key]
...
...
@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None):
    cols = {key: F.cat(to_cat(key), dim=0) for key in schemes}
    return Frame(cols)


def combine_names(names, ids=None):
    """Combine the selected names into one new name.
...
...
@@ -6148,40 +6507,59 @@ def combine_names(names, ids=None):
    str
    """
    if ids is None:
        return "+".join(sorted(names))
    else:
        selected = sorted([names[i] for i in ids])
        return "+".join(selected)


class DGLBlock(DGLGraph):
    """Subclass that signifies the graph is a block created from
    :func:`dgl.to_block`.
    """

    # (BarclayII) I'm making a subclass because I don't want to make another version of
    # serialization that contains the is_block flag.
    is_block = True

    def __repr__(self):
        if (
            len(self.srctypes) == 1
            and len(self.dsttypes) == 1
            and len(self.etypes) == 1
        ):
            ret = "Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})"
            return ret.format(
                srcnode=self.number_of_src_nodes(),
                dstnode=self.number_of_dst_nodes(),
                edge=self.number_of_edges(),
            )
        else:
            ret = (
                "Block(num_src_nodes={srcnode},\n"
                "      num_dst_nodes={dstnode},\n"
                "      num_edges={edge},\n"
                "      metagraph={meta})"
            )
            nsrcnode_dict = {
                ntype: self.number_of_src_nodes(ntype)
                for ntype in self.srctypes
            }
            ndstnode_dict = {
                ntype: self.number_of_dst_nodes(ntype)
                for ntype in self.dsttypes
            }
            nedge_dict = {
                etype: self.number_of_edges(etype)
                for etype in self.canonical_etypes
            }
            meta = str(self.metagraph().edges(keys=True))
            return ret.format(
                srcnode=nsrcnode_dict,
                dstnode=ndstnode_dict,
                edge=nedge_dict,
                meta=meta,
            )


def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
...
...
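A DGLBlock is what dgl.to_block produces for bipartite message-flow computation; its __repr__ above is what you see when printing one. A hedged example (edges invented):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
block = dgl.to_block(g, dst_nodes=torch.tensor([1, 2]))
print(block)
# Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2)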
@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
    srctype, etype, dsttype = graph.canonical_etypes[0]
    # create graph
    hgidx = heterograph_index.create_unitgraph_from_coo(
        2, len(unique_src), len(unique_dst), new_u, new_v, ["coo", "csr", "csc"]
    )
    # create frame
    srcframe = graph._node_frames[graph.get_ntype_id(srctype)].subframe(
        unique_src
    )
    srcframe[NID] = unique_src
    dstframe = graph._node_frames[graph.get_ntype_id(dsttype)].subframe(
        unique_dst
    )
    dstframe[NID] = unique_dst
    eframe = graph._edge_frames[0].subframe(eid)
    eframe[EID] = eid
    return (
        DGLGraph(
            hgidx,
            ([srctype], [dsttype]),
            [etype],
            node_frames=[srcframe, dstframe],
            edge_frames=[eframe],
        ),
        unique_src,
        unique_dst,
        eid,
    )


_init_api("dgl.heterograph")
python/dgl/heterograph_index.py
View file @
d1827488
...
...
@@ -7,12 +7,11 @@ import sys
import numpy as np
import scipy

from . import backend as F, utils
from ._ffi.function import _init_api
from ._ffi.object import ObjectBase, register_object
from ._ffi.streams import to_dgl_stream_handle
from .base import dgl_warning, DGLError
from .graph_index import from_coo
...
...