Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d1827488
Unverified
Commit
d1827488
authored
Feb 20, 2023
by
Hongzhi (Steve), Chen
Committed by
GitHub
Feb 20, 2023
Browse files
autofix2 (#5333)
Co-authored-by:
Ubuntu
<
ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal
>
parent
3e5137fe
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1129 additions
and
560 deletions
+1129
-560
python/dgl/dataloading/base.py
python/dgl/dataloading/base.py
+125
-47
python/dgl/dataloading/dist_dataloader.py
python/dgl/dataloading/dist_dataloader.py
+107
-53
python/dgl/dataloading/labor_sampler.py
python/dgl/dataloading/labor_sampler.py
+4
-6
python/dgl/dataloading/neighbor_sampler.py
python/dgl/dataloading/neighbor_sampler.py
+36
-14
python/dgl/dataloading/shadow.py
python/dgl/dataloading/shadow.py
+26
-8
python/dgl/function/message.py
python/dgl/function/message.py
+20
-10
python/dgl/generators.py
python/dgl/generators.py
+1
-2
python/dgl/geometry/capi.py
python/dgl/geometry/capi.py
+1
-2
python/dgl/graph_index.py
python/dgl/graph_index.py
+2
-3
python/dgl/heterograph.py
python/dgl/heterograph.py
+805
-412
python/dgl/heterograph_index.py
python/dgl/heterograph_index.py
+2
-3
No files found.
python/dgl/dataloading/base.py
View file @
d1827488
"""Base classes and functionalities for dataloaders"""
"""Base classes and functionalities for dataloaders"""
from
collections.abc
import
Mapping
import
inspect
import
inspect
from
..base
import
NID
,
EID
from
collections.abc
import
Mapping
from
..convert
import
heterograph
from
..
import
backend
as
F
from
..
import
backend
as
F
from
..transforms
import
compact_graphs
from
..base
import
EID
,
NID
from
..convert
import
heterograph
from
..frame
import
LazyFeature
from
..frame
import
LazyFeature
from
..utils
import
recursive_apply
,
context_of
from
..transforms
import
compact_graphs
from
..utils
import
context_of
,
recursive_apply
def
_set_lazy_features
(
x
,
xdata
,
feature_names
):
def
_set_lazy_features
(
x
,
xdata
,
feature_names
):
if
feature_names
is
None
:
if
feature_names
is
None
:
...
@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
...
@@ -17,6 +19,7 @@ def _set_lazy_features(x, xdata, feature_names):
for
type_
,
names
in
feature_names
.
items
():
for
type_
,
names
in
feature_names
.
items
():
x
[
type_
].
data
.
update
({
k
:
LazyFeature
(
k
)
for
k
in
names
})
x
[
type_
].
data
.
update
({
k
:
LazyFeature
(
k
)
for
k
in
names
})
def
set_node_lazy_features
(
g
,
feature_names
):
def
set_node_lazy_features
(
g
,
feature_names
):
"""Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
"""Assign lazy features to the ``ndata`` of the input graph for prefetching optimization.
...
@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
...
@@ -51,6 +54,7 @@ def set_node_lazy_features(g, feature_names):
"""
"""
return
_set_lazy_features
(
g
.
nodes
,
g
.
ndata
,
feature_names
)
return
_set_lazy_features
(
g
.
nodes
,
g
.
ndata
,
feature_names
)
def
set_edge_lazy_features
(
g
,
feature_names
):
def
set_edge_lazy_features
(
g
,
feature_names
):
"""Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
"""Assign lazy features to the ``edata`` of the input graph for prefetching optimization.
...
@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
...
@@ -86,6 +90,7 @@ def set_edge_lazy_features(g, feature_names):
"""
"""
return
_set_lazy_features
(
g
.
edges
,
g
.
edata
,
feature_names
)
return
_set_lazy_features
(
g
.
edges
,
g
.
edata
,
feature_names
)
def
set_src_lazy_features
(
g
,
feature_names
):
def
set_src_lazy_features
(
g
,
feature_names
):
"""Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
"""Assign lazy features to the ``srcdata`` of the input graph for prefetching optimization.
...
@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
...
@@ -120,6 +125,7 @@ def set_src_lazy_features(g, feature_names):
"""
"""
return
_set_lazy_features
(
g
.
srcnodes
,
g
.
srcdata
,
feature_names
)
return
_set_lazy_features
(
g
.
srcnodes
,
g
.
srcdata
,
feature_names
)
def
set_dst_lazy_features
(
g
,
feature_names
):
def
set_dst_lazy_features
(
g
,
feature_names
):
"""Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
"""Assign lazy features to the ``dstdata`` of the input graph for prefetching optimization.
...
@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
...
@@ -154,6 +160,7 @@ def set_dst_lazy_features(g, feature_names):
"""
"""
return
_set_lazy_features
(
g
.
dstnodes
,
g
.
dstdata
,
feature_names
)
return
_set_lazy_features
(
g
.
dstnodes
,
g
.
dstdata
,
feature_names
)
class
Sampler
(
object
):
class
Sampler
(
object
):
"""Base class for graph samplers.
"""Base class for graph samplers.
...
@@ -171,6 +178,7 @@ class Sampler(object):
...
@@ -171,6 +178,7 @@ class Sampler(object):
def sample(self, g, indices):
def sample(self, g, indices):
return g.subgraph(indices)
return g.subgraph(indices)
"""
"""
def
sample
(
self
,
g
,
indices
):
def
sample
(
self
,
g
,
indices
):
"""Abstract sample method.
"""Abstract sample method.
...
@@ -183,6 +191,7 @@ class Sampler(object):
...
@@ -183,6 +191,7 @@ class Sampler(object):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
class
BlockSampler
(
Sampler
):
class
BlockSampler
(
Sampler
):
"""Base class for sampling mini-batches in the form of Message-passing
"""Base class for sampling mini-batches in the form of Message-passing
Flow Graphs (MFGs).
Flow Graphs (MFGs).
...
@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
...
@@ -211,8 +220,14 @@ class BlockSampler(Sampler):
The device of the output subgraphs or MFGs. Default is the same as the
The device of the output subgraphs or MFGs. Default is the same as the
minibatch of seed nodes.
minibatch of seed nodes.
"""
"""
def
__init__
(
self
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
):
def
__init__
(
self
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
):
super
().
__init__
()
super
().
__init__
()
self
.
prefetch_node_feats
=
prefetch_node_feats
or
[]
self
.
prefetch_node_feats
=
prefetch_node_feats
or
[]
self
.
prefetch_labels
=
prefetch_labels
or
[]
self
.
prefetch_labels
=
prefetch_labels
or
[]
...
@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
...
@@ -238,7 +253,9 @@ class BlockSampler(Sampler):
set_edge_lazy_features
(
block
,
self
.
prefetch_edge_feats
)
set_edge_lazy_features
(
block
,
self
.
prefetch_edge_feats
)
return
input_nodes
,
output_nodes
,
blocks
return
input_nodes
,
output_nodes
,
blocks
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
"""Sample a list of blocks from the given seed nodes."""
"""Sample a list of blocks from the given seed nodes."""
result
=
self
.
sample_blocks
(
g
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
result
=
self
.
sample_blocks
(
g
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
return
self
.
assign_lazy_features
(
result
)
return
self
.
assign_lazy_features
(
result
)
...
@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
...
@@ -249,39 +266,57 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
exclude_eids
=
{
exclude_eids
=
{
k
:
F
.
cat
([
v
,
F
.
gather_row
(
reverse_eid_map
[
k
],
v
)],
0
)
k
:
F
.
cat
([
v
,
F
.
gather_row
(
reverse_eid_map
[
k
],
v
)],
0
)
for
k
,
v
in
eids
.
items
()}
for
k
,
v
in
eids
.
items
()
}
else
:
else
:
exclude_eids
=
F
.
cat
([
eids
,
F
.
gather_row
(
reverse_eid_map
,
eids
)],
0
)
exclude_eids
=
F
.
cat
([
eids
,
F
.
gather_row
(
reverse_eid_map
,
eids
)],
0
)
return
exclude_eids
return
exclude_eids
def
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
reverse_etype_map
):
def
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
reverse_etype_map
):
exclude_eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
exclude_eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
reverse_etype_map
=
{
reverse_etype_map
=
{
g
.
to_canonical_etype
(
k
):
g
.
to_canonical_etype
(
v
)
g
.
to_canonical_etype
(
k
):
g
.
to_canonical_etype
(
v
)
for
k
,
v
in
reverse_etype_map
.
items
()}
for
k
,
v
in
reverse_etype_map
.
items
()
exclude_eids
.
update
({
reverse_etype_map
[
k
]:
v
for
k
,
v
in
exclude_eids
.
items
()})
}
exclude_eids
.
update
(
{
reverse_etype_map
[
k
]:
v
for
k
,
v
in
exclude_eids
.
items
()}
)
return
exclude_eids
return
exclude_eids
def
_find_exclude_eids
(
g
,
exclude_mode
,
eids
,
**
kwargs
):
def
_find_exclude_eids
(
g
,
exclude_mode
,
eids
,
**
kwargs
):
if
exclude_mode
is
None
:
if
exclude_mode
is
None
:
return
None
return
None
elif
callable
(
exclude_mode
):
elif
callable
(
exclude_mode
):
return
exclude_mode
(
eids
)
return
exclude_mode
(
eids
)
elif
F
.
is_tensor
(
exclude_mode
)
or
(
elif
F
.
is_tensor
(
exclude_mode
)
or
(
isinstance
(
exclude_mode
,
Mapping
)
and
isinstance
(
exclude_mode
,
Mapping
)
all
(
F
.
is_tensor
(
v
)
for
v
in
exclude_mode
.
values
())):
and
all
(
F
.
is_tensor
(
v
)
for
v
in
exclude_mode
.
values
())
):
return
exclude_mode
return
exclude_mode
elif
exclude_mode
==
'
self
'
:
elif
exclude_mode
==
"
self
"
:
return
eids
return
eids
elif
exclude_mode
==
'reverse_id'
:
elif
exclude_mode
==
"reverse_id"
:
return
_find_exclude_eids_with_reverse_id
(
g
,
eids
,
kwargs
[
'reverse_eid_map'
])
return
_find_exclude_eids_with_reverse_id
(
elif
exclude_mode
==
'reverse_types'
:
g
,
eids
,
kwargs
[
"reverse_eid_map"
]
return
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
kwargs
[
'reverse_etype_map'
])
)
elif
exclude_mode
==
"reverse_types"
:
return
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
kwargs
[
"reverse_etype_map"
]
)
else
:
else
:
raise
ValueError
(
'
unsupported mode {}
'
.
format
(
exclude_mode
))
raise
ValueError
(
"
unsupported mode {}
"
.
format
(
exclude_mode
))
def
find_exclude_eids
(
g
,
seed_edges
,
exclude
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
output_device
=
None
):
def
find_exclude_eids
(
g
,
seed_edges
,
exclude
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
output_device
=
None
,
):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
Parameters
Parameters
...
@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
...
@@ -334,11 +369,15 @@ def find_exclude_eids(g, seed_edges, exclude, reverse_eids=None, reverse_etypes=
exclude
,
exclude
,
seed_edges
,
seed_edges
,
reverse_eid_map
=
reverse_eids
,
reverse_eid_map
=
reverse_eids
,
reverse_etype_map
=
reverse_etypes
)
reverse_etype_map
=
reverse_etypes
,
)
if
exclude_eids
is
not
None
and
output_device
is
not
None
:
if
exclude_eids
is
not
None
and
output_device
is
not
None
:
exclude_eids
=
recursive_apply
(
exclude_eids
,
lambda
x
:
F
.
copy_to
(
x
,
output_device
))
exclude_eids
=
recursive_apply
(
exclude_eids
,
lambda
x
:
F
.
copy_to
(
x
,
output_device
)
)
return
exclude_eids
return
exclude_eids
class
EdgePredictionSampler
(
Sampler
):
class
EdgePredictionSampler
(
Sampler
):
"""Sampler class that wraps an existing sampler for node classification into another
"""Sampler class that wraps an existing sampler for node classification into another
one for edge classification or link prediction.
one for edge classification or link prediction.
...
@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
...
@@ -347,15 +386,24 @@ class EdgePredictionSampler(Sampler):
--------
--------
as_edge_prediction_sampler
as_edge_prediction_sampler
"""
"""
def
__init__
(
self
,
sampler
,
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
,
prefetch_labels
=
None
):
def
__init__
(
self
,
sampler
,
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
,
prefetch_labels
=
None
,
):
super
().
__init__
()
super
().
__init__
()
# Check if the sampler's sample method has an optional third argument.
# Check if the sampler's sample method has an optional third argument.
argspec
=
inspect
.
getfullargspec
(
sampler
.
sample
)
argspec
=
inspect
.
getfullargspec
(
sampler
.
sample
)
if
len
(
argspec
.
args
)
<
4
:
# ['self', 'g', 'indices', 'exclude_eids']
if
len
(
argspec
.
args
)
<
4
:
# ['self', 'g', 'indices', 'exclude_eids']
raise
TypeError
(
raise
TypeError
(
"This sampler does not support edge or link prediction; please add an"
"This sampler does not support edge or link prediction; please add an"
"optional third argument for edge IDs to exclude in its sample() method."
)
"optional third argument for edge IDs to exclude in its sample() method."
)
self
.
reverse_eids
=
reverse_eids
self
.
reverse_eids
=
reverse_eids
self
.
reverse_etypes
=
reverse_etypes
self
.
reverse_etypes
=
reverse_etypes
self
.
exclude
=
exclude
self
.
exclude
=
exclude
...
@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
...
@@ -367,20 +415,27 @@ class EdgePredictionSampler(Sampler):
def
_build_neg_graph
(
self
,
g
,
seed_edges
):
def
_build_neg_graph
(
self
,
g
,
seed_edges
):
neg_srcdst
=
self
.
negative_sampler
(
g
,
seed_edges
)
neg_srcdst
=
self
.
negative_sampler
(
g
,
seed_edges
)
if
not
isinstance
(
neg_srcdst
,
Mapping
):
if
not
isinstance
(
neg_srcdst
,
Mapping
):
assert
len
(
g
.
canonical_etypes
)
==
1
,
\
assert
len
(
g
.
canonical_etypes
)
==
1
,
(
'graph has multiple or no edge types; '
\
"graph has multiple or no edge types; "
'please return a dict in negative sampler.'
"please return a dict in negative sampler."
)
neg_srcdst
=
{
g
.
canonical_etypes
[
0
]:
neg_srcdst
}
neg_srcdst
=
{
g
.
canonical_etypes
[
0
]:
neg_srcdst
}
dtype
=
F
.
dtype
(
list
(
neg_srcdst
.
values
())[
0
][
0
])
dtype
=
F
.
dtype
(
list
(
neg_srcdst
.
values
())[
0
][
0
])
ctx
=
context_of
(
seed_edges
)
if
seed_edges
is
not
None
else
g
.
device
ctx
=
context_of
(
seed_edges
)
if
seed_edges
is
not
None
else
g
.
device
neg_edges
=
{
neg_edges
=
{
etype
:
neg_srcdst
.
get
(
etype
,
etype
:
neg_srcdst
.
get
(
(
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
=
ctx
),
etype
,
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
=
ctx
)))
(
for
etype
in
g
.
canonical_etypes
}
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
=
ctx
),
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
=
ctx
),
),
)
for
etype
in
g
.
canonical_etypes
}
neg_pair_graph
=
heterograph
(
neg_pair_graph
=
heterograph
(
neg_edges
,
{
ntype
:
g
.
num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
})
neg_edges
,
{
ntype
:
g
.
num_nodes
(
ntype
)
for
ntype
in
g
.
ntypes
}
)
return
neg_pair_graph
return
neg_pair_graph
def
assign_lazy_features
(
self
,
result
):
def
assign_lazy_features
(
self
,
result
):
...
@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
...
@@ -398,10 +453,13 @@ class EdgePredictionSampler(Sampler):
negative pairs as edges.
negative pairs as edges.
"""
"""
if
isinstance
(
seed_edges
,
Mapping
):
if
isinstance
(
seed_edges
,
Mapping
):
seed_edges
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
seed_edges
.
items
()}
seed_edges
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
seed_edges
.
items
()
}
exclude
=
self
.
exclude
exclude
=
self
.
exclude
pair_graph
=
g
.
edge_subgraph
(
pair_graph
=
g
.
edge_subgraph
(
seed_edges
,
relabel_nodes
=
False
,
output_device
=
self
.
output_device
)
seed_edges
,
relabel_nodes
=
False
,
output_device
=
self
.
output_device
)
eids
=
pair_graph
.
edata
[
EID
]
eids
=
pair_graph
.
edata
[
EID
]
if
self
.
negative_sampler
is
not
None
:
if
self
.
negative_sampler
is
not
None
:
...
@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
...
@@ -414,19 +472,34 @@ class EdgePredictionSampler(Sampler):
seed_nodes
=
pair_graph
.
ndata
[
NID
]
seed_nodes
=
pair_graph
.
ndata
[
NID
]
exclude_eids
=
find_exclude_eids
(
exclude_eids
=
find_exclude_eids
(
g
,
seed_edges
,
exclude
,
self
.
reverse_eids
,
self
.
reverse_etypes
,
g
,
self
.
output_device
)
seed_edges
,
exclude
,
self
.
reverse_eids
,
self
.
reverse_etypes
,
self
.
output_device
,
)
input_nodes
,
_
,
blocks
=
self
.
sampler
.
sample
(
g
,
seed_nodes
,
exclude_eids
)
input_nodes
,
_
,
blocks
=
self
.
sampler
.
sample
(
g
,
seed_nodes
,
exclude_eids
)
if
self
.
negative_sampler
is
None
:
if
self
.
negative_sampler
is
None
:
return
self
.
assign_lazy_features
((
input_nodes
,
pair_graph
,
blocks
))
return
self
.
assign_lazy_features
((
input_nodes
,
pair_graph
,
blocks
))
else
:
else
:
return
self
.
assign_lazy_features
((
input_nodes
,
pair_graph
,
neg_graph
,
blocks
))
return
self
.
assign_lazy_features
(
(
input_nodes
,
pair_graph
,
neg_graph
,
blocks
)
)
def
as_edge_prediction_sampler
(
def
as_edge_prediction_sampler
(
sampler
,
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
,
sampler
,
prefetch_labels
=
None
):
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
,
prefetch_labels
=
None
,
):
"""Create an edge-wise sampler from a node-wise sampler.
"""Create an edge-wise sampler from a node-wise sampler.
For each batch of edges, the sampler applies the provided node-wise sampler to
For each batch of edges, the sampler applies the provided node-wise sampler to
...
@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
...
@@ -571,5 +644,10 @@ def as_edge_prediction_sampler(
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
... train_on(input_nodes, pair_graph, neg_pair_graph, blocks)
"""
"""
return
EdgePredictionSampler
(
return
EdgePredictionSampler
(
sampler
,
exclude
=
exclude
,
reverse_eids
=
reverse_eids
,
reverse_etypes
=
reverse_etypes
,
sampler
,
negative_sampler
=
negative_sampler
,
prefetch_labels
=
prefetch_labels
)
exclude
=
exclude
,
reverse_eids
=
reverse_eids
,
reverse_etypes
=
reverse_etypes
,
negative_sampler
=
negative_sampler
,
prefetch_labels
=
prefetch_labels
,
)
python/dgl/dataloading/dist_dataloader.py
View file @
d1827488
"""Distributed dataloaders.
"""Distributed dataloaders.
"""
"""
import
inspect
import
inspect
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
abc
import
ABC
,
abstractproperty
,
abstractmethod
from
..
import
transforms
from
..
import
backend
as
F
,
transforms
,
utils
from
..base
import
NID
,
EID
from
..base
import
EID
,
NID
from
..
import
backend
as
F
from
..
import
utils
from
..convert
import
heterograph
from
..convert
import
heterograph
from
..distributed
import
DistDataLoader
from
..distributed
import
DistDataLoader
...
@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
...
@@ -20,19 +19,25 @@ def _find_exclude_eids_with_reverse_id(g, eids, reverse_eid_map):
eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
exclude_eids
=
{
exclude_eids
=
{
k
:
F
.
cat
([
v
,
F
.
gather_row
(
reverse_eid_map
[
k
],
v
)],
0
)
k
:
F
.
cat
([
v
,
F
.
gather_row
(
reverse_eid_map
[
k
],
v
)],
0
)
for
k
,
v
in
eids
.
items
()}
for
k
,
v
in
eids
.
items
()
}
else
:
else
:
exclude_eids
=
F
.
cat
([
eids
,
F
.
gather_row
(
reverse_eid_map
,
eids
)],
0
)
exclude_eids
=
F
.
cat
([
eids
,
F
.
gather_row
(
reverse_eid_map
,
eids
)],
0
)
return
exclude_eids
return
exclude_eids
def
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
reverse_etype_map
):
def
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
reverse_etype_map
):
exclude_eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
exclude_eids
=
{
g
.
to_canonical_etype
(
k
):
v
for
k
,
v
in
eids
.
items
()}
reverse_etype_map
=
{
reverse_etype_map
=
{
g
.
to_canonical_etype
(
k
):
g
.
to_canonical_etype
(
v
)
g
.
to_canonical_etype
(
k
):
g
.
to_canonical_etype
(
v
)
for
k
,
v
in
reverse_etype_map
.
items
()}
for
k
,
v
in
reverse_etype_map
.
items
()
exclude_eids
.
update
({
reverse_etype_map
[
k
]:
v
for
k
,
v
in
exclude_eids
.
items
()})
}
exclude_eids
.
update
(
{
reverse_etype_map
[
k
]:
v
for
k
,
v
in
exclude_eids
.
items
()}
)
return
exclude_eids
return
exclude_eids
def
_find_exclude_eids
(
g
,
exclude_mode
,
eids
,
**
kwargs
):
def
_find_exclude_eids
(
g
,
exclude_mode
,
eids
,
**
kwargs
):
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
"""Find all edge IDs to exclude according to :attr:`exclude_mode`.
...
@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
...
@@ -77,14 +82,18 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
"""
"""
if
exclude_mode
is
None
:
if
exclude_mode
is
None
:
return
None
return
None
elif
exclude_mode
==
'
self
'
:
elif
exclude_mode
==
"
self
"
:
return
eids
return
eids
elif
exclude_mode
==
'reverse_id'
:
elif
exclude_mode
==
"reverse_id"
:
return
_find_exclude_eids_with_reverse_id
(
g
,
eids
,
kwargs
[
'reverse_eid_map'
])
return
_find_exclude_eids_with_reverse_id
(
elif
exclude_mode
==
'reverse_types'
:
g
,
eids
,
kwargs
[
"reverse_eid_map"
]
return
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
kwargs
[
'reverse_etype_map'
])
)
elif
exclude_mode
==
"reverse_types"
:
return
_find_exclude_eids_with_reverse_types
(
g
,
eids
,
kwargs
[
"reverse_etype_map"
]
)
else
:
else
:
raise
ValueError
(
'
unsupported mode {}
'
.
format
(
exclude_mode
))
raise
ValueError
(
"
unsupported mode {}
"
.
format
(
exclude_mode
))
class
Collator
(
ABC
):
class
Collator
(
ABC
):
...
@@ -100,6 +109,7 @@ class Collator(ABC):
...
@@ -100,6 +109,7 @@ class Collator(ABC):
:ref:`User Guide Section 6 <guide-minibatch>` and
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
"""
@
abstractproperty
@
abstractproperty
def
dataset
(
self
):
def
dataset
(
self
):
"""Returns the dataset object of the collator."""
"""Returns the dataset object of the collator."""
...
@@ -122,6 +132,7 @@ class Collator(ABC):
...
@@ -122,6 +132,7 @@ class Collator(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
class
NodeCollator
(
Collator
):
class
NodeCollator
(
Collator
):
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
"""DGL collator to combine nodes and their computation dependencies within a minibatch for
training node classification or regression on a single graph with neighborhood sampling.
training node classification or regression on a single graph with neighborhood sampling.
...
@@ -155,14 +166,16 @@ class NodeCollator(Collator):
...
@@ -155,14 +166,16 @@ class NodeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
"""
def
__init__
(
self
,
g
,
nids
,
graph_sampler
):
def
__init__
(
self
,
g
,
nids
,
graph_sampler
):
self
.
g
=
g
self
.
g
=
g
if
not
isinstance
(
nids
,
Mapping
):
if
not
isinstance
(
nids
,
Mapping
):
assert
len
(
g
.
ntypes
)
==
1
,
\
assert
(
"nids should be a dict of node type and ids for graph with multiple node types"
len
(
g
.
ntypes
)
==
1
),
"nids should be a dict of node type and ids for graph with multiple node types"
self
.
graph_sampler
=
graph_sampler
self
.
graph_sampler
=
graph_sampler
self
.
nids
=
utils
.
prepare_tensor_or_dict
(
g
,
nids
,
'
nids
'
)
self
.
nids
=
utils
.
prepare_tensor_or_dict
(
g
,
nids
,
"
nids
"
)
self
.
_dataset
=
utils
.
maybe_flatten_dict
(
self
.
nids
)
self
.
_dataset
=
utils
.
maybe_flatten_dict
(
self
.
nids
)
@
property
@
property
...
@@ -197,12 +210,15 @@ class NodeCollator(Collator):
...
@@ -197,12 +210,15 @@ class NodeCollator(Collator):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g
,
items
,
'
items
'
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g
,
items
,
"
items
"
)
input_nodes
,
output_nodes
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
self
.
g
,
items
)
input_nodes
,
output_nodes
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
self
.
g
,
items
)
return
input_nodes
,
output_nodes
,
blocks
return
input_nodes
,
output_nodes
,
blocks
class
EdgeCollator
(
Collator
):
class
EdgeCollator
(
Collator
):
"""DGL collator to combine edges and their computation dependencies within a minibatch for
"""DGL collator to combine edges and their computation dependencies within a minibatch for
training edge classification, edge regression, or link prediction on a single graph
training edge classification, edge regression, or link prediction on a single graph
...
@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
...
@@ -380,12 +396,23 @@ class EdgeCollator(Collator):
:ref:`User Guide Section 6 <guide-minibatch>` and
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
"""
def
__init__
(
self
,
g
,
eids
,
graph_sampler
,
g_sampling
=
None
,
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
):
def
__init__
(
self
,
g
,
eids
,
graph_sampler
,
g_sampling
=
None
,
exclude
=
None
,
reverse_eids
=
None
,
reverse_etypes
=
None
,
negative_sampler
=
None
,
):
self
.
g
=
g
self
.
g
=
g
if
not
isinstance
(
eids
,
Mapping
):
if
not
isinstance
(
eids
,
Mapping
):
assert
len
(
g
.
etypes
)
==
1
,
\
assert
(
"eids should be a dict of etype and ids for graph with multiple etypes"
len
(
g
.
etypes
)
==
1
),
"eids should be a dict of etype and ids for graph with multiple etypes"
self
.
graph_sampler
=
graph_sampler
self
.
graph_sampler
=
graph_sampler
# One may wish to iterate over the edges in one graph while perform sampling in
# One may wish to iterate over the edges in one graph while perform sampling in
...
@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
...
@@ -404,7 +431,7 @@ class EdgeCollator(Collator):
self
.
reverse_etypes
=
reverse_etypes
self
.
reverse_etypes
=
reverse_etypes
self
.
negative_sampler
=
negative_sampler
self
.
negative_sampler
=
negative_sampler
self
.
eids
=
utils
.
prepare_tensor_or_dict
(
g
,
eids
,
'
eids
'
)
self
.
eids
=
utils
.
prepare_tensor_or_dict
(
g
,
eids
,
"
eids
"
)
self
.
_dataset
=
utils
.
maybe_flatten_dict
(
self
.
eids
)
self
.
_dataset
=
utils
.
maybe_flatten_dict
(
self
.
eids
)
@
property
@
property
...
@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
...
@@ -415,7 +442,7 @@ class EdgeCollator(Collator):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g_sampling
,
items
,
'
items
'
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g_sampling
,
items
,
"
items
"
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
)
seed_nodes
=
pair_graph
.
ndata
[
NID
]
seed_nodes
=
pair_graph
.
ndata
[
NID
]
...
@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
...
@@ -425,10 +452,12 @@ class EdgeCollator(Collator):
self
.
exclude
,
self
.
exclude
,
items
,
items
,
reverse_eid_map
=
self
.
reverse_eids
,
reverse_eid_map
=
self
.
reverse_eids
,
reverse_etype_map
=
self
.
reverse_etypes
)
reverse_etype_map
=
self
.
reverse_etypes
,
)
input_nodes
,
_
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
input_nodes
,
_
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
self
.
g_sampling
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
self
.
g_sampling
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
return
input_nodes
,
pair_graph
,
blocks
return
input_nodes
,
pair_graph
,
blocks
...
@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
...
@@ -436,28 +465,39 @@ class EdgeCollator(Collator):
if
isinstance
(
items
[
0
],
tuple
):
if
isinstance
(
items
[
0
],
tuple
):
# returns a list of pairs: group them by node types into a dict
# returns a list of pairs: group them by node types into a dict
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
group_as_dict
(
items
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g_sampling
,
items
,
'
items
'
)
items
=
utils
.
prepare_tensor_or_dict
(
self
.
g_sampling
,
items
,
"
items
"
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
,
relabel_nodes
=
False
)
pair_graph
=
self
.
g
.
edge_subgraph
(
items
,
relabel_nodes
=
False
)
induced_edges
=
pair_graph
.
edata
[
EID
]
induced_edges
=
pair_graph
.
edata
[
EID
]
neg_srcdst
=
self
.
negative_sampler
(
self
.
g
,
items
)
neg_srcdst
=
self
.
negative_sampler
(
self
.
g
,
items
)
if
not
isinstance
(
neg_srcdst
,
Mapping
):
if
not
isinstance
(
neg_srcdst
,
Mapping
):
assert
len
(
self
.
g
.
etypes
)
==
1
,
\
assert
len
(
self
.
g
.
etypes
)
==
1
,
(
'graph has multiple or no edge types; '
\
"graph has multiple or no edge types; "
'please return a dict in negative sampler.'
"please return a dict in negative sampler."
)
neg_srcdst
=
{
self
.
g
.
canonical_etypes
[
0
]:
neg_srcdst
}
neg_srcdst
=
{
self
.
g
.
canonical_etypes
[
0
]:
neg_srcdst
}
# Get dtype from a tuple of tensors
# Get dtype from a tuple of tensors
dtype
=
F
.
dtype
(
list
(
neg_srcdst
.
values
())[
0
][
0
])
dtype
=
F
.
dtype
(
list
(
neg_srcdst
.
values
())[
0
][
0
])
ctx
=
F
.
context
(
pair_graph
)
ctx
=
F
.
context
(
pair_graph
)
neg_edges
=
{
neg_edges
=
{
etype
:
neg_srcdst
.
get
(
etype
,
(
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
),
etype
:
neg_srcdst
.
get
(
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
)))
etype
,
for
etype
in
self
.
g
.
canonical_etypes
}
(
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
),
F
.
copy_to
(
F
.
tensor
([],
dtype
),
ctx
),
),
)
for
etype
in
self
.
g
.
canonical_etypes
}
neg_pair_graph
=
heterograph
(
neg_pair_graph
=
heterograph
(
neg_edges
,
{
ntype
:
self
.
g
.
number_of_nodes
(
ntype
)
for
ntype
in
self
.
g
.
ntypes
})
neg_edges
,
{
ntype
:
self
.
g
.
number_of_nodes
(
ntype
)
for
ntype
in
self
.
g
.
ntypes
},
)
pair_graph
,
neg_pair_graph
=
transforms
.
compact_graphs
([
pair_graph
,
neg_pair_graph
])
pair_graph
,
neg_pair_graph
=
transforms
.
compact_graphs
(
[
pair_graph
,
neg_pair_graph
]
)
pair_graph
.
edata
[
EID
]
=
induced_edges
pair_graph
.
edata
[
EID
]
=
induced_edges
seed_nodes
=
pair_graph
.
ndata
[
NID
]
seed_nodes
=
pair_graph
.
ndata
[
NID
]
...
@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
...
@@ -467,10 +507,12 @@ class EdgeCollator(Collator):
self
.
exclude
,
self
.
exclude
,
items
,
items
,
reverse_eid_map
=
self
.
reverse_eids
,
reverse_eid_map
=
self
.
reverse_eids
,
reverse_etype_map
=
self
.
reverse_etypes
)
reverse_etype_map
=
self
.
reverse_etypes
,
)
input_nodes
,
_
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
input_nodes
,
_
,
blocks
=
self
.
graph_sampler
.
sample_blocks
(
self
.
g_sampling
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
self
.
g_sampling
,
seed_nodes
,
exclude_eids
=
exclude_eids
)
return
input_nodes
,
pair_graph
,
neg_pair_graph
,
blocks
return
input_nodes
,
pair_graph
,
neg_pair_graph
,
blocks
...
@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
...
@@ -517,13 +559,14 @@ class EdgeCollator(Collator):
def
_remove_kwargs_dist
(
kwargs
):
def
_remove_kwargs_dist
(
kwargs
):
if
'
num_workers
'
in
kwargs
:
if
"
num_workers
"
in
kwargs
:
del
kwargs
[
'
num_workers
'
]
del
kwargs
[
"
num_workers
"
]
if
'
pin_memory
'
in
kwargs
:
if
"
pin_memory
"
in
kwargs
:
del
kwargs
[
'
pin_memory
'
]
del
kwargs
[
"
pin_memory
"
]
print
(
'
Distributed DataLoaders do not support pin_memory.
'
)
print
(
"
Distributed DataLoaders do not support pin_memory.
"
)
return
kwargs
return
kwargs
class
DistNodeDataLoader
(
DistDataLoader
):
class
DistNodeDataLoader
(
DistDataLoader
):
"""Sampled graph data loader over nodes for distributed graph storage.
"""Sampled graph data loader over nodes for distributed graph storage.
...
@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
...
@@ -547,6 +590,7 @@ class DistNodeDataLoader(DistDataLoader):
--------
--------
dgl.dataloading.DataLoader
dgl.dataloading.DataLoader
"""
"""
def
__init__
(
self
,
g
,
nids
,
graph_sampler
,
device
=
None
,
**
kwargs
):
def
__init__
(
self
,
g
,
nids
,
graph_sampler
,
device
=
None
,
**
kwargs
):
collator_kwargs
=
{}
collator_kwargs
=
{}
dataloader_kwargs
=
{}
dataloader_kwargs
=
{}
...
@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
...
@@ -558,17 +602,22 @@ class DistNodeDataLoader(DistDataLoader):
dataloader_kwargs
[
k
]
=
v
dataloader_kwargs
[
k
]
=
v
if
device
is
None
:
if
device
is
None
:
# for the distributed case default to the CPU
# for the distributed case default to the CPU
device
=
'cpu'
device
=
"cpu"
assert
device
==
'cpu'
,
'Only cpu is supported in the case of a DistGraph.'
assert
(
device
==
"cpu"
),
"Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
# and does not copy features. Fallback to normal solution
self
.
collator
=
NodeCollator
(
g
,
nids
,
graph_sampler
,
**
collator_kwargs
)
self
.
collator
=
NodeCollator
(
g
,
nids
,
graph_sampler
,
**
collator_kwargs
)
_remove_kwargs_dist
(
dataloader_kwargs
)
_remove_kwargs_dist
(
dataloader_kwargs
)
super
().
__init__
(
self
.
collator
.
dataset
,
super
().
__init__
(
self
.
collator
.
dataset
,
collate_fn
=
self
.
collator
.
collate
,
collate_fn
=
self
.
collator
.
collate
,
**
dataloader_kwargs
)
**
dataloader_kwargs
)
self
.
device
=
device
self
.
device
=
device
class
DistEdgeDataLoader
(
DistDataLoader
):
class
DistEdgeDataLoader
(
DistDataLoader
):
"""Sampled graph data loader over edges for distributed graph storage.
"""Sampled graph data loader over edges for distributed graph storage.
...
@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
...
@@ -593,6 +642,7 @@ class DistEdgeDataLoader(DistDataLoader):
--------
--------
dgl.dataloading.DataLoader
dgl.dataloading.DataLoader
"""
"""
def
__init__
(
self
,
g
,
eids
,
graph_sampler
,
device
=
None
,
**
kwargs
):
def
__init__
(
self
,
g
,
eids
,
graph_sampler
,
device
=
None
,
**
kwargs
):
collator_kwargs
=
{}
collator_kwargs
=
{}
dataloader_kwargs
=
{}
dataloader_kwargs
=
{}
...
@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
...
@@ -605,14 +655,18 @@ class DistEdgeDataLoader(DistDataLoader):
if
device
is
None
:
if
device
is
None
:
# for the distributed case default to the CPU
# for the distributed case default to the CPU
device
=
'cpu'
device
=
"cpu"
assert
device
==
'cpu'
,
'Only cpu is supported in the case of a DistGraph.'
assert
(
device
==
"cpu"
),
"Only cpu is supported in the case of a DistGraph."
# Distributed DataLoader currently does not support heterogeneous graphs
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
# and does not copy features. Fallback to normal solution
self
.
collator
=
EdgeCollator
(
g
,
eids
,
graph_sampler
,
**
collator_kwargs
)
self
.
collator
=
EdgeCollator
(
g
,
eids
,
graph_sampler
,
**
collator_kwargs
)
_remove_kwargs_dist
(
dataloader_kwargs
)
_remove_kwargs_dist
(
dataloader_kwargs
)
super
().
__init__
(
self
.
collator
.
dataset
,
super
().
__init__
(
self
.
collator
.
dataset
,
collate_fn
=
self
.
collator
.
collate
,
collate_fn
=
self
.
collator
.
collate
,
**
dataloader_kwargs
)
**
dataloader_kwargs
)
self
.
device
=
device
self
.
device
=
device
python/dgl/dataloading/labor_sampler.py
View file @
d1827488
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
#
#
"""Data loading components for labor sampling"""
"""Data loading components for labor sampling"""
from
..base
import
NID
,
EID
from
..
import
backend
as
F
from
..base
import
EID
,
NID
from
..random
import
choice
from
..transforms
import
to_block
from
..transforms
import
to_block
from
.base
import
BlockSampler
from
.base
import
BlockSampler
from
..random
import
choice
from
..
import
backend
as
F
class
LaborSampler
(
BlockSampler
):
class
LaborSampler
(
BlockSampler
):
...
@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
...
@@ -211,9 +211,7 @@ class LaborSampler(BlockSampler):
)
)
block
.
edata
[
EID
]
=
eid
block
.
edata
[
EID
]
=
eid
if
len
(
g
.
canonical_etypes
)
>
1
:
if
len
(
g
.
canonical_etypes
)
>
1
:
for
etype
,
importance
in
zip
(
for
etype
,
importance
in
zip
(
g
.
canonical_etypes
,
importances
):
g
.
canonical_etypes
,
importances
):
if
importance
.
shape
[
0
]
==
block
.
num_edges
(
etype
):
if
importance
.
shape
[
0
]
==
block
.
num_edges
(
etype
):
block
.
edata
[
"edge_weights"
][
etype
]
=
importance
block
.
edata
[
"edge_weights"
][
etype
]
=
importance
elif
importances
[
0
].
shape
[
0
]
==
block
.
num_edges
():
elif
importances
[
0
].
shape
[
0
]
==
block
.
num_edges
():
...
...
python/dgl/dataloading/neighbor_sampler.py
View file @
d1827488
"""Data loading components for neighbor sampling"""
"""Data loading components for neighbor sampling"""
from
..base
import
N
ID
,
E
ID
from
..base
import
E
ID
,
N
ID
from
..transforms
import
to_block
from
..transforms
import
to_block
from
.base
import
BlockSampler
from
.base
import
BlockSampler
class
NeighborSampler
(
BlockSampler
):
class
NeighborSampler
(
BlockSampler
):
"""Sampler that builds computational dependency of node representations via
"""Sampler that builds computational dependency of node representations via
neighbor sampling for multilayer GNN.
neighbor sampling for multilayer GNN.
...
@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
...
@@ -107,20 +108,33 @@ class NeighborSampler(BlockSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
"""
def
__init__
(
self
,
fanouts
,
edge_dir
=
'in'
,
prob
=
None
,
mask
=
None
,
replace
=
False
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
def
__init__
(
output_device
=
None
):
self
,
super
().
__init__
(
prefetch_node_feats
=
prefetch_node_feats
,
fanouts
,
edge_dir
=
"in"
,
prob
=
None
,
mask
=
None
,
replace
=
False
,
prefetch_node_feats
=
None
,
prefetch_labels
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
):
super
().
__init__
(
prefetch_node_feats
=
prefetch_node_feats
,
prefetch_labels
=
prefetch_labels
,
prefetch_labels
=
prefetch_labels
,
prefetch_edge_feats
=
prefetch_edge_feats
,
prefetch_edge_feats
=
prefetch_edge_feats
,
output_device
=
output_device
)
output_device
=
output_device
,
)
self
.
fanouts
=
fanouts
self
.
fanouts
=
fanouts
self
.
edge_dir
=
edge_dir
self
.
edge_dir
=
edge_dir
if
mask
is
not
None
and
prob
is
not
None
:
if
mask
is
not
None
and
prob
is
not
None
:
raise
ValueError
(
raise
ValueError
(
'Mask and probability arguments are mutually exclusive. '
"Mask and probability arguments are mutually exclusive. "
'Consider multiplying the probability with the mask '
"Consider multiplying the probability with the mask "
'to achieve the same goal.'
)
"to achieve the same goal."
)
self
.
prob
=
prob
or
mask
self
.
prob
=
prob
or
mask
self
.
replace
=
replace
self
.
replace
=
replace
...
@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
...
@@ -129,9 +143,14 @@ class NeighborSampler(BlockSampler):
blocks
=
[]
blocks
=
[]
for
fanout
in
reversed
(
self
.
fanouts
):
for
fanout
in
reversed
(
self
.
fanouts
):
frontier
=
g
.
sample_neighbors
(
frontier
=
g
.
sample_neighbors
(
seed_nodes
,
fanout
,
edge_dir
=
self
.
edge_dir
,
prob
=
self
.
prob
,
seed_nodes
,
replace
=
self
.
replace
,
output_device
=
self
.
output_device
,
fanout
,
exclude_edges
=
exclude_eids
)
edge_dir
=
self
.
edge_dir
,
prob
=
self
.
prob
,
replace
=
self
.
replace
,
output_device
=
self
.
output_device
,
exclude_edges
=
exclude_eids
,
)
eid
=
frontier
.
edata
[
EID
]
eid
=
frontier
.
edata
[
EID
]
block
=
to_block
(
frontier
,
seed_nodes
)
block
=
to_block
(
frontier
,
seed_nodes
)
block
.
edata
[
EID
]
=
eid
block
.
edata
[
EID
]
=
eid
...
@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
...
@@ -140,8 +159,10 @@ class NeighborSampler(BlockSampler):
return
seed_nodes
,
output_nodes
,
blocks
return
seed_nodes
,
output_nodes
,
blocks
MultiLayerNeighborSampler
=
NeighborSampler
MultiLayerNeighborSampler
=
NeighborSampler
class
MultiLayerFullNeighborSampler
(
NeighborSampler
):
class
MultiLayerFullNeighborSampler
(
NeighborSampler
):
"""Sampler that builds computational dependency of node representations by taking messages
"""Sampler that builds computational dependency of node representations by taking messages
from all neighbors for multilayer GNN.
from all neighbors for multilayer GNN.
...
@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
...
@@ -174,5 +195,6 @@ class MultiLayerFullNeighborSampler(NeighborSampler):
:ref:`User Guide Section 6 <guide-minibatch>` and
:ref:`User Guide Section 6 <guide-minibatch>` and
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
"""
def
__init__
(
self
,
num_layers
,
**
kwargs
):
def
__init__
(
self
,
num_layers
,
**
kwargs
):
super
().
__init__
([
-
1
]
*
num_layers
,
**
kwargs
)
super
().
__init__
([
-
1
]
*
num_layers
,
**
kwargs
)
python/dgl/dataloading/shadow.py
View file @
d1827488
"""ShaDow-GNN subgraph samplers."""
"""ShaDow-GNN subgraph samplers."""
from
..sampling.utils
import
EidExcluder
from
..
import
transforms
from
..
import
transforms
from
..base
import
NID
from
..base
import
NID
from
.base
import
set_node_lazy_features
,
set_edge_lazy_features
,
Sampler
from
..sampling.utils
import
EidExcluder
from
.base
import
Sampler
,
set_edge_lazy_features
,
set_node_lazy_features
class
ShaDowKHopSampler
(
Sampler
):
class
ShaDowKHopSampler
(
Sampler
):
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
"""K-hop subgraph sampler from `Deep Graph Neural Networks with Shallow
...
@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
...
@@ -68,8 +69,16 @@ class ShaDowKHopSampler(Sampler):
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> g.edata['p'] = torch.rand(g.num_edges()) # any non-negative 1D vector works
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
>>> sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15], prob='p')
"""
"""
def
__init__
(
self
,
fanouts
,
replace
=
False
,
prob
=
None
,
prefetch_node_feats
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
):
def
__init__
(
self
,
fanouts
,
replace
=
False
,
prob
=
None
,
prefetch_node_feats
=
None
,
prefetch_edge_feats
=
None
,
output_device
=
None
,
):
super
().
__init__
()
super
().
__init__
()
self
.
fanouts
=
fanouts
self
.
fanouts
=
fanouts
self
.
replace
=
replace
self
.
replace
=
replace
...
@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
...
@@ -78,7 +87,9 @@ class ShaDowKHopSampler(Sampler):
self
.
prefetch_edge_feats
=
prefetch_edge_feats
self
.
prefetch_edge_feats
=
prefetch_edge_feats
self
.
output_device
=
output_device
self
.
output_device
=
output_device
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
def
sample
(
self
,
g
,
seed_nodes
,
exclude_eids
=
None
):
# pylint: disable=arguments-differ
"""Sampling function.
"""Sampling function.
Parameters
Parameters
...
@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
...
@@ -99,12 +110,19 @@ class ShaDowKHopSampler(Sampler):
output_nodes
=
seed_nodes
output_nodes
=
seed_nodes
for
fanout
in
reversed
(
self
.
fanouts
):
for
fanout
in
reversed
(
self
.
fanouts
):
frontier
=
g
.
sample_neighbors
(
frontier
=
g
.
sample_neighbors
(
seed_nodes
,
fanout
,
output_device
=
self
.
output_device
,
seed_nodes
,
replace
=
self
.
replace
,
prob
=
self
.
prob
,
exclude_edges
=
exclude_eids
)
fanout
,
output_device
=
self
.
output_device
,
replace
=
self
.
replace
,
prob
=
self
.
prob
,
exclude_edges
=
exclude_eids
,
)
block
=
transforms
.
to_block
(
frontier
,
seed_nodes
)
block
=
transforms
.
to_block
(
frontier
,
seed_nodes
)
seed_nodes
=
block
.
srcdata
[
NID
]
seed_nodes
=
block
.
srcdata
[
NID
]
subg
=
g
.
subgraph
(
seed_nodes
,
relabel_nodes
=
True
,
output_device
=
self
.
output_device
)
subg
=
g
.
subgraph
(
seed_nodes
,
relabel_nodes
=
True
,
output_device
=
self
.
output_device
)
if
exclude_eids
is
not
None
:
if
exclude_eids
is
not
None
:
subg
=
EidExcluder
(
exclude_eids
)(
subg
)
subg
=
EidExcluder
(
exclude_eids
)(
subg
)
...
...
python/dgl/function/message.py
View file @
d1827488
...
@@ -7,8 +7,7 @@ from itertools import product
...
@@ -7,8 +7,7 @@ from itertools import product
from
.base
import
BuiltinFunction
,
TargetCode
from
.base
import
BuiltinFunction
,
TargetCode
__all__
=
[
"copy_u"
,
"copy_e"
,
__all__
=
[
"copy_u"
,
"copy_e"
,
"BinaryMessageFunction"
,
"CopyMessageFunction"
]
"BinaryMessageFunction"
,
"CopyMessageFunction"
]
class
MessageFunction
(
BuiltinFunction
):
class
MessageFunction
(
BuiltinFunction
):
...
@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
...
@@ -27,6 +26,7 @@ class BinaryMessageFunction(MessageFunction):
--------
--------
u_mul_e
u_mul_e
"""
"""
def
__init__
(
self
,
binary_op
,
lhs
,
rhs
,
lhs_field
,
rhs_field
,
out_field
):
def
__init__
(
self
,
binary_op
,
lhs
,
rhs
,
lhs_field
,
rhs_field
,
out_field
):
self
.
binary_op
=
binary_op
self
.
binary_op
=
binary_op
self
.
lhs
=
lhs
self
.
lhs
=
lhs
...
@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
...
@@ -49,6 +49,7 @@ class CopyMessageFunction(MessageFunction):
--------
--------
copy_u
copy_u
"""
"""
def
__init__
(
self
,
target
,
in_field
,
out_field
):
def
__init__
(
self
,
target
,
in_field
,
out_field
):
self
.
target
=
target
self
.
target
=
target
self
.
in_field
=
in_field
self
.
in_field
=
in_field
...
@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
...
@@ -151,17 +152,25 @@ def _gen_message_builtin(lhs, rhs, binary_op):
--------
--------
>>> import dgl
>>> import dgl
>>> message_func = dgl.function.{}('h', 'h', 'm')
>>> message_func = dgl.function.{}('h', 'h', 'm')
"""
.
format
(
binary_op
,
"""
.
format
(
binary_op
,
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
lhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
lhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
rhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
rhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
lhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
lhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
rhs
]],
TargetCode
.
CODE2STR
[
_TARGET_MAP
[
rhs
]],
name
)
name
,
)
def
func
(
lhs_field
,
rhs_field
,
out
):
def
func
(
lhs_field
,
rhs_field
,
out
):
return
BinaryMessageFunction
(
return
BinaryMessageFunction
(
binary_op
,
_TARGET_MAP
[
lhs
],
binary_op
,
_TARGET_MAP
[
rhs
],
lhs_field
,
rhs_field
,
out
)
_TARGET_MAP
[
lhs
],
_TARGET_MAP
[
rhs
],
lhs_field
,
rhs_field
,
out
,
)
func
.
__name__
=
name
func
.
__name__
=
name
func
.
__doc__
=
docstring
func
.
__doc__
=
docstring
return
func
return
func
...
@@ -177,4 +186,5 @@ def _register_builtin_message_func():
...
@@ -177,4 +186,5 @@ def _register_builtin_message_func():
setattr
(
sys
.
modules
[
__name__
],
func
.
__name__
,
func
)
setattr
(
sys
.
modules
[
__name__
],
func
.
__name__
,
func
)
__all__
.
append
(
func
.
__name__
)
__all__
.
append
(
func
.
__name__
)
_register_builtin_message_func
()
_register_builtin_message_func
()
python/dgl/generators.py
View file @
d1827488
"""Module for various graph generator functions."""
"""Module for various graph generator functions."""
from
.
import
backend
as
F
from
.
import
backend
as
F
,
convert
,
random
from
.
import
convert
,
random
__all__
=
[
"rand_graph"
,
"rand_bipartite"
]
__all__
=
[
"rand_graph"
,
"rand_bipartite"
]
...
...
python/dgl/geometry/capi.py
View file @
d1827488
"""Python interfaces to DGL farthest point sampler."""
"""Python interfaces to DGL farthest point sampler."""
import
numpy
as
np
import
numpy
as
np
from
..
import
backend
as
F
from
..
import
backend
as
F
,
ndarray
as
nd
from
..
import
ndarray
as
nd
from
.._ffi.base
import
DGLError
from
.._ffi.base
import
DGLError
from
.._ffi.function
import
_init_api
from
.._ffi.function
import
_init_api
...
...
python/dgl/graph_index.py
View file @
d1827488
...
@@ -5,11 +5,10 @@ import networkx as nx
...
@@ -5,11 +5,10 @@ import networkx as nx
import
numpy
as
np
import
numpy
as
np
import
scipy
import
scipy
from
.
import
backend
as
F
from
.
import
backend
as
F
,
utils
from
.
import
utils
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
._ffi.object
import
ObjectBase
,
register_object
from
._ffi.object
import
ObjectBase
,
register_object
from
.base
import
DGLError
,
dgl_warning
from
.base
import
dgl_warning
,
DGLError
class
BoolFlag
(
object
):
class
BoolFlag
(
object
):
...
...
python/dgl/heterograph.py
View file @
d1827488
"""Classes for heterogeneous graphs."""
"""Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections.abc
import
Mapping
,
Iterable
from
contextlib
import
contextmanager
import
copy
import
copy
import
numbers
import
itertools
import
itertools
import
numbers
# pylint: disable= too-many-lines
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
contextlib
import
contextmanager
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
import
numpy
as
np
from
.
import
backend
as
F
,
core
,
graph_index
,
heterograph_index
,
utils
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
.ops
import
segment
from
.base
import
(
from
.base
import
ALL
,
SLICE_FULL
,
NTYPE
,
NID
,
ETYPE
,
EID
,
is_all
,
DGLError
,
dgl_warning
ALL
,
from
.
import
core
dgl_warning
,
from
.
import
graph_index
DGLError
,
from
.
import
heterograph_index
EID
,
from
.
import
utils
ETYPE
,
from
.
import
backend
as
F
is_all
,
NID
,
NTYPE
,
SLICE_FULL
,
)
from
.frame
import
Frame
from
.frame
import
Frame
from
.view
import
HeteroNodeView
,
HeteroNodeDataView
,
HeteroEdgeView
,
HeteroEdgeDataView
from
.ops
import
segment
from
.view
import
(
HeteroEdgeDataView
,
HeteroEdgeView
,
HeteroNodeDataView
,
HeteroNodeView
,
)
__all__
=
[
"DGLGraph"
,
"combine_names"
]
__all__
=
[
'DGLGraph'
,
'combine_names'
]
class
DGLGraph
(
object
):
class
DGLGraph
(
object
):
"""Class for storing graph structure and node/edge feature data.
"""Class for storing graph structure and node/edge feature data.
...
@@ -35,16 +50,19 @@ class DGLGraph(object):
...
@@ -35,16 +50,19 @@ class DGLGraph(object):
Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its
Read the user guide chapter :ref:`guide-graph` for an in-depth explanation about its
usage.
usage.
"""
"""
is_block
=
False
is_block
=
False
# pylint: disable=unused-argument, dangerous-default-value
# pylint: disable=unused-argument, dangerous-default-value
def
__init__
(
self
,
def
__init__
(
self
,
gidx
=
[],
gidx
=
[],
ntypes
=
[
'
_N
'
],
ntypes
=
[
"
_N
"
],
etypes
=
[
'
_E
'
],
etypes
=
[
"
_E
"
],
node_frames
=
None
,
node_frames
=
None
,
edge_frames
=
None
,
edge_frames
=
None
,
**
deprecate_kwargs
):
**
deprecate_kwargs
):
"""Internal constructor for creating a DGLGraph.
"""Internal constructor for creating a DGLGraph.
Parameters
Parameters
...
@@ -67,21 +85,42 @@ class DGLGraph(object):
...
@@ -67,21 +85,42 @@ class DGLGraph(object):
of edge type i. (default: None)
of edge type i. (default: None)
"""
"""
if
isinstance
(
gidx
,
DGLGraph
):
if
isinstance
(
gidx
,
DGLGraph
):
raise
DGLError
(
'The input is already a DGLGraph. No need to create it again.'
)
raise
DGLError
(
"The input is already a DGLGraph. No need to create it again."
)
if
not
isinstance
(
gidx
,
heterograph_index
.
HeteroGraphIndex
):
if
not
isinstance
(
gidx
,
heterograph_index
.
HeteroGraphIndex
):
dgl_warning
(
'Recommend creating graphs by `dgl.graph(data)`'
dgl_warning
(
' instead of `dgl.DGLGraph(data)`.'
)
"Recommend creating graphs by `dgl.graph(data)`"
(
sparse_fmt
,
arrays
),
num_src
,
num_dst
=
utils
.
graphdata2tensors
(
gidx
)
" instead of `dgl.DGLGraph(data)`."
if
sparse_fmt
==
'coo'
:
)
(
sparse_fmt
,
arrays
),
num_src
,
num_dst
=
utils
.
graphdata2tensors
(
gidx
)
if
sparse_fmt
==
"coo"
:
gidx
=
heterograph_index
.
create_unitgraph_from_coo
(
gidx
=
heterograph_index
.
create_unitgraph_from_coo
(
1
,
num_src
,
num_dst
,
arrays
[
0
],
arrays
[
1
],
[
'coo'
,
'csr'
,
'csc'
])
1
,
num_src
,
num_dst
,
arrays
[
0
],
arrays
[
1
],
[
"coo"
,
"csr"
,
"csc"
],
)
else
:
else
:
gidx
=
heterograph_index
.
create_unitgraph_from_csr
(
gidx
=
heterograph_index
.
create_unitgraph_from_csr
(
1
,
num_src
,
num_dst
,
arrays
[
0
],
arrays
[
1
],
arrays
[
2
],
[
'coo'
,
'csr'
,
'csc'
],
1
,
sparse_fmt
==
'csc'
)
num_src
,
num_dst
,
arrays
[
0
],
arrays
[
1
],
arrays
[
2
],
[
"coo"
,
"csr"
,
"csc"
],
sparse_fmt
==
"csc"
,
)
if
len
(
deprecate_kwargs
)
!=
0
:
if
len
(
deprecate_kwargs
)
!=
0
:
dgl_warning
(
'Keyword arguments {} are deprecated in v0.5, and can be safely'
dgl_warning
(
' removed in all cases.'
.
format
(
list
(
deprecate_kwargs
.
keys
())))
"Keyword arguments {} are deprecated in v0.5, and can be safely"
" removed in all cases."
.
format
(
list
(
deprecate_kwargs
.
keys
()))
)
self
.
_init
(
gidx
,
ntypes
,
etypes
,
node_frames
,
edge_frames
)
self
.
_init
(
gidx
,
ntypes
,
etypes
,
node_frames
,
edge_frames
)
def
_init
(
self
,
gidx
,
ntypes
,
etypes
,
node_frames
,
edge_frames
):
def
_init
(
self
,
gidx
,
ntypes
,
etypes
,
node_frames
,
edge_frames
):
...
@@ -94,39 +133,51 @@ class DGLGraph(object):
...
@@ -94,39 +133,51 @@ class DGLGraph(object):
# Handle node types
# Handle node types
if
isinstance
(
ntypes
,
tuple
):
if
isinstance
(
ntypes
,
tuple
):
if
len
(
ntypes
)
!=
2
:
if
len
(
ntypes
)
!=
2
:
errmsg
=
'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'
.
format
(
errmsg
=
"Invalid input. Expect a pair (srctypes, dsttypes) but got {}"
.
format
(
ntypes
)
ntypes
)
raise
TypeError
(
errmsg
)
raise
TypeError
(
errmsg
)
if
not
self
.
_graph
.
is_metagraph_unibipartite
():
if
not
self
.
_graph
.
is_metagraph_unibipartite
():
raise
ValueError
(
'Invalid input. The metagraph must be a uni-directional'
raise
ValueError
(
' bipartite graph.'
)
"Invalid input. The metagraph must be a uni-directional"
" bipartite graph."
)
self
.
_ntypes
=
ntypes
[
0
]
+
ntypes
[
1
]
self
.
_ntypes
=
ntypes
[
0
]
+
ntypes
[
1
]
self
.
_srctypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
ntypes
[
0
])}
self
.
_srctypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
ntypes
[
0
])}
self
.
_dsttypes_invmap
=
{
t
:
i
+
len
(
ntypes
[
0
])
for
i
,
t
in
enumerate
(
ntypes
[
1
])}
self
.
_dsttypes_invmap
=
{
t
:
i
+
len
(
ntypes
[
0
])
for
i
,
t
in
enumerate
(
ntypes
[
1
])
}
self
.
_is_unibipartite
=
True
self
.
_is_unibipartite
=
True
if
len
(
ntypes
[
0
])
==
1
and
len
(
ntypes
[
1
])
==
1
and
len
(
etypes
)
==
1
:
if
len
(
ntypes
[
0
])
==
1
and
len
(
ntypes
[
1
])
==
1
and
len
(
etypes
)
==
1
:
self
.
_canonical_etypes
=
[(
ntypes
[
0
][
0
],
etypes
[
0
],
ntypes
[
1
][
0
])]
self
.
_canonical_etypes
=
[
(
ntypes
[
0
][
0
],
etypes
[
0
],
ntypes
[
1
][
0
])
]
else
:
else
:
self
.
_ntypes
=
ntypes
self
.
_ntypes
=
ntypes
if
len
(
ntypes
)
==
1
:
if
len
(
ntypes
)
==
1
:
src_dst_map
=
None
src_dst_map
=
None
else
:
else
:
src_dst_map
=
find_src_dst_ntypes
(
self
.
_ntypes
,
self
.
_graph
.
metagraph
)
src_dst_map
=
find_src_dst_ntypes
(
self
.
_is_unibipartite
=
(
src_dst_map
is
not
None
)
self
.
_ntypes
,
self
.
_graph
.
metagraph
)
self
.
_is_unibipartite
=
src_dst_map
is
not
None
if
self
.
_is_unibipartite
:
if
self
.
_is_unibipartite
:
self
.
_srctypes_invmap
,
self
.
_dsttypes_invmap
=
src_dst_map
self
.
_srctypes_invmap
,
self
.
_dsttypes_invmap
=
src_dst_map
else
:
else
:
self
.
_srctypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
self
.
_ntypes
)}
self
.
_srctypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
self
.
_ntypes
)
}
self
.
_dsttypes_invmap
=
self
.
_srctypes_invmap
self
.
_dsttypes_invmap
=
self
.
_srctypes_invmap
# Handle edge types
# Handle edge types
self
.
_etypes
=
etypes
self
.
_etypes
=
etypes
if
self
.
_canonical_etypes
is
None
:
if
self
.
_canonical_etypes
is
None
:
if
(
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
)
:
if
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
:
self
.
_canonical_etypes
=
[(
ntypes
[
0
],
etypes
[
0
],
ntypes
[
0
])]
self
.
_canonical_etypes
=
[(
ntypes
[
0
],
etypes
[
0
],
ntypes
[
0
])]
else
:
else
:
self
.
_canonical_etypes
=
make_canonical_etypes
(
self
.
_canonical_etypes
=
make_canonical_etypes
(
self
.
_etypes
,
self
.
_ntypes
,
self
.
_graph
.
metagraph
)
self
.
_etypes
,
self
.
_ntypes
,
self
.
_graph
.
metagraph
)
# An internal map from etype to canonical etype tuple.
# An internal map from etype to canonical etype tuple.
# If two etypes have the same name, an empty tuple is stored instead to indicate
# If two etypes have the same name, an empty tuple is stored instead to indicate
...
@@ -137,21 +188,29 @@ class DGLGraph(object):
...
@@ -137,21 +188,29 @@ class DGLGraph(object):
self
.
_etype2canonical
[
ety
]
=
tuple
()
self
.
_etype2canonical
[
ety
]
=
tuple
()
else
:
else
:
self
.
_etype2canonical
[
ety
]
=
self
.
_canonical_etypes
[
i
]
self
.
_etype2canonical
[
ety
]
=
self
.
_canonical_etypes
[
i
]
self
.
_etypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
self
.
_canonical_etypes
)}
self
.
_etypes_invmap
=
{
t
:
i
for
i
,
t
in
enumerate
(
self
.
_canonical_etypes
)
}
# node and edge frame
# node and edge frame
if
node_frames
is
None
:
if
node_frames
is
None
:
node_frames
=
[
None
]
*
len
(
self
.
_ntypes
)
node_frames
=
[
None
]
*
len
(
self
.
_ntypes
)
node_frames
=
[
Frame
(
num_rows
=
self
.
_graph
.
number_of_nodes
(
i
))
node_frames
=
[
if
frame
is
None
else
frame
Frame
(
num_rows
=
self
.
_graph
.
number_of_nodes
(
i
))
for
i
,
frame
in
enumerate
(
node_frames
)]
if
frame
is
None
else
frame
for
i
,
frame
in
enumerate
(
node_frames
)
]
self
.
_node_frames
=
node_frames
self
.
_node_frames
=
node_frames
if
edge_frames
is
None
:
if
edge_frames
is
None
:
edge_frames
=
[
None
]
*
len
(
self
.
_etypes
)
edge_frames
=
[
None
]
*
len
(
self
.
_etypes
)
edge_frames
=
[
Frame
(
num_rows
=
self
.
_graph
.
number_of_edges
(
i
))
edge_frames
=
[
if
frame
is
None
else
frame
Frame
(
num_rows
=
self
.
_graph
.
number_of_edges
(
i
))
for
i
,
frame
in
enumerate
(
edge_frames
)]
if
frame
is
None
else
frame
for
i
,
frame
in
enumerate
(
edge_frames
)
]
self
.
_edge_frames
=
edge_frames
self
.
_edge_frames
=
edge_frames
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
...
@@ -162,40 +221,60 @@ class DGLGraph(object):
...
@@ -162,40 +221,60 @@ class DGLGraph(object):
self
.
__dict__
.
update
(
state
)
self
.
__dict__
.
update
(
state
)
elif
isinstance
(
state
,
tuple
)
and
len
(
state
)
==
5
:
elif
isinstance
(
state
,
tuple
)
and
len
(
state
)
==
5
:
# DGL == 0.4.3
# DGL == 0.4.3
dgl_warning
(
"The object is pickled with DGL == 0.4.3. "
dgl_warning
(
"Some of the original attributes are ignored."
)
"The object is pickled with DGL == 0.4.3. "
"Some of the original attributes are ignored."
)
self
.
_init
(
*
state
)
self
.
_init
(
*
state
)
elif
isinstance
(
state
,
dict
):
elif
isinstance
(
state
,
dict
):
# DGL <= 0.4.2
# DGL <= 0.4.2
dgl_warning
(
"The object is pickled with DGL <= 0.4.2. "
dgl_warning
(
"Some of the original attributes are ignored."
)
"The object is pickled with DGL <= 0.4.2. "
self
.
_init
(
state
[
'_graph'
],
state
[
'_ntypes'
],
state
[
'_etypes'
],
state
[
'_node_frames'
],
"Some of the original attributes are ignored."
state
[
'_edge_frames'
])
)
self
.
_init
(
state
[
"_graph"
],
state
[
"_ntypes"
],
state
[
"_etypes"
],
state
[
"_node_frames"
],
state
[
"_edge_frames"
],
)
else
:
else
:
raise
IOError
(
"Unrecognized pickle format."
)
raise
IOError
(
"Unrecognized pickle format."
)
def
__repr__
(
self
):
def
__repr__
(
self
):
if
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
:
if
len
(
self
.
ntypes
)
==
1
and
len
(
self
.
etypes
)
==
1
:
ret
=
(
'Graph(num_nodes={node}, num_edges={edge},
\n
'
ret
=
(
' ndata_schemes={ndata}
\n
'
"Graph(num_nodes={node}, num_edges={edge},
\n
"
' edata_schemes={edata})'
)
" ndata_schemes={ndata}
\n
"
return
ret
.
format
(
node
=
self
.
number_of_nodes
(),
edge
=
self
.
number_of_edges
(),
" edata_schemes={edata})"
)
return
ret
.
format
(
node
=
self
.
number_of_nodes
(),
edge
=
self
.
number_of_edges
(),
ndata
=
str
(
self
.
node_attr_schemes
()),
ndata
=
str
(
self
.
node_attr_schemes
()),
edata
=
str
(
self
.
edge_attr_schemes
()))
edata
=
str
(
self
.
edge_attr_schemes
()),
)
else
:
else
:
ret
=
(
'Graph(num_nodes={node},
\n
'
ret
=
(
' num_edges={edge},
\n
'
"Graph(num_nodes={node},
\n
"
' metagraph={meta})'
)
" num_edges={edge},
\n
"
nnode_dict
=
{
self
.
ntypes
[
i
]
:
self
.
_graph
.
number_of_nodes
(
i
)
" metagraph={meta})"
for
i
in
range
(
len
(
self
.
ntypes
))}
)
nedge_dict
=
{
self
.
canonical_etypes
[
i
]
:
self
.
_graph
.
number_of_edges
(
i
)
nnode_dict
=
{
for
i
in
range
(
len
(
self
.
etypes
))}
self
.
ntypes
[
i
]:
self
.
_graph
.
number_of_nodes
(
i
)
for
i
in
range
(
len
(
self
.
ntypes
))
}
nedge_dict
=
{
self
.
canonical_etypes
[
i
]:
self
.
_graph
.
number_of_edges
(
i
)
for
i
in
range
(
len
(
self
.
etypes
))
}
meta
=
str
(
self
.
metagraph
().
edges
(
keys
=
True
))
meta
=
str
(
self
.
metagraph
().
edges
(
keys
=
True
))
return
ret
.
format
(
node
=
nnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
)
return
ret
.
format
(
node
=
nnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
)
def
__copy__
(
self
):
def
__copy__
(
self
):
"""Shallow copy implementation."""
"""Shallow copy implementation."""
#TODO(minjie): too many states in python; should clean up and lower to C
#
TODO(minjie): too many states in python; should clean up and lower to C
cls
=
type
(
self
)
cls
=
type
(
self
)
obj
=
cls
.
__new__
(
cls
)
obj
=
cls
.
__new__
(
cls
)
obj
.
__dict__
.
update
(
self
.
__dict__
)
obj
.
__dict__
.
update
(
self
.
__dict__
)
...
@@ -298,14 +377,16 @@ class DGLGraph(object):
...
@@ -298,14 +377,16 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support add_nodes
# TODO(xiangsx): block do not support add_nodes
if
ntype
is
None
:
if
ntype
is
None
:
if
self
.
_graph
.
number_of_ntypes
()
!=
1
:
if
self
.
_graph
.
number_of_ntypes
()
!=
1
:
raise
DGLError
(
'Node type name must be specified if there are more than one '
raise
DGLError
(
'node types.'
)
"Node type name must be specified if there are more than one "
"node types."
)
# nothing happen
# nothing happen
if
num
==
0
:
if
num
==
0
:
return
return
assert
num
>
0
,
'
Number of new nodes should be larger than one.
'
assert
num
>
0
,
"
Number of new nodes should be larger than one.
"
ntid
=
self
.
get_ntype_id
(
ntype
)
ntid
=
self
.
get_ntype_id
(
ntype
)
# update graph idx
# update graph idx
metagraph
=
self
.
_graph
.
metagraph
metagraph
=
self
.
_graph
.
metagraph
...
@@ -319,23 +400,32 @@ class DGLGraph(object):
...
@@ -319,23 +400,32 @@ class DGLGraph(object):
relation_graphs
=
[]
relation_graphs
=
[]
for
c_etype
in
self
.
canonical_etypes
:
for
c_etype
in
self
.
canonical_etypes
:
# src or dst == ntype, update the relation graph
# src or dst == ntype, update the relation graph
if
self
.
get_ntype_id
(
c_etype
[
0
])
==
ntid
or
self
.
get_ntype_id
(
c_etype
[
2
])
==
ntid
:
if
(
u
,
v
=
self
.
edges
(
form
=
'uv'
,
order
=
'eid'
,
etype
=
c_etype
)
self
.
get_ntype_id
(
c_etype
[
0
])
==
ntid
or
self
.
get_ntype_id
(
c_etype
[
2
])
==
ntid
):
u
,
v
=
self
.
edges
(
form
=
"uv"
,
order
=
"eid"
,
etype
=
c_etype
)
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
1
if
c_etype
[
0
]
==
c_etype
[
2
]
else
2
,
1
if
c_etype
[
0
]
==
c_etype
[
2
]
else
2
,
self
.
number_of_nodes
(
c_etype
[
0
])
+
\
self
.
number_of_nodes
(
c_etype
[
0
])
(
num
if
self
.
get_ntype_id
(
c_etype
[
0
])
==
ntid
else
0
),
+
(
num
if
self
.
get_ntype_id
(
c_etype
[
0
])
==
ntid
else
0
),
self
.
number_of_nodes
(
c_etype
[
2
])
+
\
self
.
number_of_nodes
(
c_etype
[
2
])
(
num
if
self
.
get_ntype_id
(
c_etype
[
2
])
==
ntid
else
0
),
+
(
num
if
self
.
get_ntype_id
(
c_etype
[
2
])
==
ntid
else
0
),
u
,
u
,
v
,
v
,
[
'coo'
,
'csr'
,
'csc'
])
[
"coo"
,
"csr"
,
"csc"
],
)
relation_graphs
.
append
(
hgidx
)
relation_graphs
.
append
(
hgidx
)
else
:
else
:
# do nothing
# do nothing
relation_graphs
.
append
(
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
)))
relation_graphs
.
append
(
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
))
)
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
))
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
),
)
self
.
_graph
=
hgidx
self
.
_graph
=
hgidx
# update data frames
# update data frames
...
@@ -452,26 +542,33 @@ class DGLGraph(object):
...
@@ -452,26 +542,33 @@ class DGLGraph(object):
remove_edges
remove_edges
"""
"""
# TODO(xiangsx): block do not support add_edges
# TODO(xiangsx): block do not support add_edges
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
if
etype
is
None
:
if
etype
is
None
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
raise
DGLError
(
'edge types.'
)
"Edge type name must be specified if there are more than one "
"edge types."
)
# nothing changed
# nothing changed
if
len
(
u
)
==
0
or
len
(
v
)
==
0
:
if
len
(
u
)
==
0
or
len
(
v
)
==
0
:
return
return
assert
len
(
u
)
==
len
(
v
)
or
len
(
u
)
==
1
or
len
(
v
)
==
1
,
\
assert
len
(
u
)
==
len
(
v
)
or
len
(
u
)
==
1
or
len
(
v
)
==
1
,
(
'The number of source nodes and the number of destination nodes should be same, '
\
"The number of source nodes and the number of destination nodes should be same, "
'or either the number of source nodes or the number of destination nodes is 1.'
"or either the number of source nodes or the number of destination nodes is 1."
)
if
len
(
u
)
==
1
and
len
(
v
)
>
1
:
if
len
(
u
)
==
1
and
len
(
v
)
>
1
:
u
=
F
.
full_1d
(
len
(
v
),
F
.
as_scalar
(
u
),
dtype
=
F
.
dtype
(
u
),
ctx
=
F
.
context
(
u
))
u
=
F
.
full_1d
(
len
(
v
),
F
.
as_scalar
(
u
),
dtype
=
F
.
dtype
(
u
),
ctx
=
F
.
context
(
u
)
)
if
len
(
v
)
==
1
and
len
(
u
)
>
1
:
if
len
(
v
)
==
1
and
len
(
u
)
>
1
:
v
=
F
.
full_1d
(
len
(
u
),
F
.
as_scalar
(
v
),
dtype
=
F
.
dtype
(
v
),
ctx
=
F
.
context
(
v
))
v
=
F
.
full_1d
(
len
(
u
),
F
.
as_scalar
(
v
),
dtype
=
F
.
dtype
(
v
),
ctx
=
F
.
context
(
v
)
)
u_type
,
e_type
,
v_type
=
self
.
to_canonical_etype
(
etype
)
u_type
,
e_type
,
v_type
=
self
.
to_canonical_etype
(
etype
)
# if end nodes of adding edges does not exists
# if end nodes of adding edges does not exists
...
@@ -501,22 +598,28 @@ class DGLGraph(object):
...
@@ -501,22 +598,28 @@ class DGLGraph(object):
for
c_etype
in
self
.
canonical_etypes
:
for
c_etype
in
self
.
canonical_etypes
:
# the target edge type
# the target edge type
if
c_etype
==
(
u_type
,
e_type
,
v_type
):
if
c_etype
==
(
u_type
,
e_type
,
v_type
):
old_u
,
old_v
=
self
.
edges
(
form
=
'
uv
'
,
order
=
'
eid
'
,
etype
=
c_etype
)
old_u
,
old_v
=
self
.
edges
(
form
=
"
uv
"
,
order
=
"
eid
"
,
etype
=
c_etype
)
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
1
if
u_type
==
v_type
else
2
,
1
if
u_type
==
v_type
else
2
,
self
.
number_of_nodes
(
u_type
),
self
.
number_of_nodes
(
u_type
),
self
.
number_of_nodes
(
v_type
),
self
.
number_of_nodes
(
v_type
),
F
.
cat
([
old_u
,
u
],
dim
=
0
),
F
.
cat
([
old_u
,
u
],
dim
=
0
),
F
.
cat
([
old_v
,
v
],
dim
=
0
),
F
.
cat
([
old_v
,
v
],
dim
=
0
),
[
'coo'
,
'csr'
,
'csc'
])
[
"coo"
,
"csr"
,
"csc"
],
)
relation_graphs
.
append
(
hgidx
)
relation_graphs
.
append
(
hgidx
)
else
:
else
:
# do nothing
# do nothing
# Note: node range change has been handled in add_nodes()
# Note: node range change has been handled in add_nodes()
relation_graphs
.
append
(
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
)))
relation_graphs
.
append
(
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
))
)
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
hgidx
=
heterograph_index
.
create_heterograph_from_relations
(
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
))
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
),
)
self
.
_graph
=
hgidx
self
.
_graph
=
hgidx
# handle data
# handle data
...
@@ -607,15 +710,19 @@ class DGLGraph(object):
...
@@ -607,15 +710,19 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_edges
# TODO(xiangsx): block do not support remove_edges
if
etype
is
None
:
if
etype
is
None
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
\
raise
DGLError
(
'edge types.'
)
"Edge type name must be specified if there are more than one "
eids
=
utils
.
prepare_tensor
(
self
,
eids
,
'u'
)
"edge types."
)
eids
=
utils
.
prepare_tensor
(
self
,
eids
,
"u"
)
if
len
(
eids
)
==
0
:
if
len
(
eids
)
==
0
:
# no edge to delete
# no edge to delete
return
return
assert
self
.
number_of_edges
(
etype
)
>
F
.
as_scalar
(
F
.
max
(
eids
,
dim
=
0
)),
\
assert
self
.
number_of_edges
(
etype
)
>
F
.
as_scalar
(
'The input eid {} is out of the range [0:{})'
.
format
(
F
.
max
(
eids
,
dim
=
0
)
F
.
as_scalar
(
F
.
max
(
eids
,
dim
=
0
)),
self
.
number_of_edges
(
etype
))
),
"The input eid {} is out of the range [0:{})"
.
format
(
F
.
as_scalar
(
F
.
max
(
eids
,
dim
=
0
)),
self
.
number_of_edges
(
etype
)
)
# edge_subgraph
# edge_subgraph
edges
=
{}
edges
=
{}
...
@@ -623,25 +730,36 @@ class DGLGraph(object):
...
@@ -623,25 +730,36 @@ class DGLGraph(object):
for
c_etype
in
self
.
canonical_etypes
:
for
c_etype
in
self
.
canonical_etypes
:
# the target edge type
# the target edge type
if
c_etype
==
(
u_type
,
e_type
,
v_type
):
if
c_etype
==
(
u_type
,
e_type
,
v_type
):
origin_eids
=
self
.
edges
(
form
=
'
eid
'
,
order
=
'
eid
'
,
etype
=
c_etype
)
origin_eids
=
self
.
edges
(
form
=
"
eid
"
,
order
=
"
eid
"
,
etype
=
c_etype
)
edges
[
c_etype
]
=
utils
.
compensate
(
eids
,
origin_eids
)
edges
[
c_etype
]
=
utils
.
compensate
(
eids
,
origin_eids
)
else
:
else
:
edges
[
c_etype
]
=
self
.
edges
(
form
=
'eid'
,
order
=
'eid'
,
etype
=
c_etype
)
edges
[
c_etype
]
=
self
.
edges
(
form
=
"eid"
,
order
=
"eid"
,
etype
=
c_etype
)
# If the graph is batched, update batch_num_edges
# If the graph is batched, update batch_num_edges
batched
=
self
.
_batch_num_edges
is
not
None
batched
=
self
.
_batch_num_edges
is
not
None
if
batched
:
if
batched
:
c_etype
=
(
u_type
,
e_type
,
v_type
)
c_etype
=
(
u_type
,
e_type
,
v_type
)
one_hot_removed_edges
=
F
.
zeros
((
self
.
num_edges
(
c_etype
),),
F
.
float32
,
self
.
device
)
one_hot_removed_edges
=
F
.
zeros
(
one_hot_removed_edges
=
F
.
scatter_row
(
one_hot_removed_edges
,
eids
,
(
self
.
num_edges
(
c_etype
),),
F
.
float32
,
self
.
device
F
.
full_1d
(
len
(
eids
),
1.
,
F
.
float32
,
self
.
device
))
)
one_hot_removed_edges
=
F
.
scatter_row
(
one_hot_removed_edges
,
eids
,
F
.
full_1d
(
len
(
eids
),
1.0
,
F
.
float32
,
self
.
device
),
)
c_etype_batch_num_edges
=
self
.
_batch_num_edges
[
c_etype
]
c_etype_batch_num_edges
=
self
.
_batch_num_edges
[
c_etype
]
batch_num_removed_edges
=
segment
.
segment_reduce
(
c_etype_batch_num_edges
,
batch_num_removed_edges
=
segment
.
segment_reduce
(
one_hot_removed_edges
,
reducer
=
'sum'
)
c_etype_batch_num_edges
,
one_hot_removed_edges
,
reducer
=
"sum"
self
.
_batch_num_edges
[
c_etype
]
=
c_etype_batch_num_edges
-
\
)
F
.
astype
(
batch_num_removed_edges
,
F
.
int64
)
self
.
_batch_num_edges
[
c_etype
]
=
c_etype_batch_num_edges
-
F
.
astype
(
batch_num_removed_edges
,
F
.
int64
sub_g
=
self
.
edge_subgraph
(
edges
,
relabel_nodes
=
False
,
store_ids
=
store_ids
)
)
sub_g
=
self
.
edge_subgraph
(
edges
,
relabel_nodes
=
False
,
store_ids
=
store_ids
)
self
.
_graph
=
sub_g
.
_graph
self
.
_graph
=
sub_g
.
_graph
self
.
_node_frames
=
sub_g
.
_node_frames
self
.
_node_frames
=
sub_g
.
_node_frames
self
.
_edge_frames
=
sub_g
.
_edge_frames
self
.
_edge_frames
=
sub_g
.
_edge_frames
...
@@ -733,16 +851,20 @@ class DGLGraph(object):
...
@@ -733,16 +851,20 @@ class DGLGraph(object):
# TODO(xiangsx): block do not support remove_nodes
# TODO(xiangsx): block do not support remove_nodes
if
ntype
is
None
:
if
ntype
is
None
:
if
self
.
_graph
.
number_of_ntypes
()
!=
1
:
if
self
.
_graph
.
number_of_ntypes
()
!=
1
:
raise
DGLError
(
'Node type name must be specified if there are more than one '
\
raise
DGLError
(
'node types.'
)
"Node type name must be specified if there are more than one "
"node types."
)
nids
=
utils
.
prepare_tensor
(
self
,
nids
,
'u'
)
nids
=
utils
.
prepare_tensor
(
self
,
nids
,
"u"
)
if
len
(
nids
)
==
0
:
if
len
(
nids
)
==
0
:
# no node to delete
# no node to delete
return
return
assert
self
.
number_of_nodes
(
ntype
)
>
F
.
as_scalar
(
F
.
max
(
nids
,
dim
=
0
)),
\
assert
self
.
number_of_nodes
(
ntype
)
>
F
.
as_scalar
(
'The input nids {} is out of the range [0:{})'
.
format
(
F
.
max
(
nids
,
dim
=
0
)
F
.
as_scalar
(
F
.
max
(
nids
,
dim
=
0
)),
self
.
number_of_nodes
(
ntype
))
),
"The input nids {} is out of the range [0:{})"
.
format
(
F
.
as_scalar
(
F
.
max
(
nids
,
dim
=
0
)),
self
.
number_of_nodes
(
ntype
)
)
ntid
=
self
.
get_ntype_id
(
ntype
)
ntid
=
self
.
get_ntype_id
(
ntype
)
nodes
=
{}
nodes
=
{}
...
@@ -757,18 +879,28 @@ class DGLGraph(object):
...
@@ -757,18 +879,28 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_nodes
# If the graph is batched, update batch_num_nodes
batched
=
self
.
_batch_num_nodes
is
not
None
batched
=
self
.
_batch_num_nodes
is
not
None
if
batched
:
if
batched
:
one_hot_removed_nodes
=
F
.
zeros
((
self
.
num_nodes
(
target_ntype
),),
one_hot_removed_nodes
=
F
.
zeros
(
F
.
float32
,
self
.
device
)
(
self
.
num_nodes
(
target_ntype
),),
F
.
float32
,
self
.
device
one_hot_removed_nodes
=
F
.
scatter_row
(
one_hot_removed_nodes
,
nids
,
)
F
.
full_1d
(
len
(
nids
),
1.
,
F
.
float32
,
self
.
device
))
one_hot_removed_nodes
=
F
.
scatter_row
(
one_hot_removed_nodes
,
nids
,
F
.
full_1d
(
len
(
nids
),
1.0
,
F
.
float32
,
self
.
device
),
)
c_ntype_batch_num_nodes
=
self
.
_batch_num_nodes
[
target_ntype
]
c_ntype_batch_num_nodes
=
self
.
_batch_num_nodes
[
target_ntype
]
batch_num_removed_nodes
=
segment
.
segment_reduce
(
batch_num_removed_nodes
=
segment
.
segment_reduce
(
c_ntype_batch_num_nodes
,
one_hot_removed_nodes
,
reducer
=
'sum'
)
c_ntype_batch_num_nodes
,
one_hot_removed_nodes
,
reducer
=
"sum"
self
.
_batch_num_nodes
[
target_ntype
]
=
c_ntype_batch_num_nodes
-
\
)
F
.
astype
(
batch_num_removed_nodes
,
F
.
int64
)
self
.
_batch_num_nodes
[
target_ntype
]
=
c_ntype_batch_num_nodes
-
F
.
astype
(
batch_num_removed_nodes
,
F
.
int64
)
# Record old num_edges to check later whether some edges were removed
# Record old num_edges to check later whether some edges were removed
old_num_edges
=
{
c_etype
:
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
old_num_edges
=
{
for
c_etype
in
self
.
canonical_etypes
}
c_etype
:
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
for
c_etype
in
self
.
canonical_etypes
}
# node_subgraph
# node_subgraph
# If batch_num_edges is to be updated, record the original edge IDs
# If batch_num_edges is to be updated, record the original edge IDs
...
@@ -780,22 +912,36 @@ class DGLGraph(object):
...
@@ -780,22 +912,36 @@ class DGLGraph(object):
# If the graph is batched, update batch_num_edges
# If the graph is batched, update batch_num_edges
if
batched
:
if
batched
:
canonical_etypes
=
[
canonical_etypes
=
[
c_etype
for
c_etype
in
self
.
canonical_etypes
if
c_etype
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
!=
old_num_edges
[
c_etype
]]
for
c_etype
in
self
.
canonical_etypes
if
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
!=
old_num_edges
[
c_etype
]
]
for
c_etype
in
canonical_etypes
:
for
c_etype
in
canonical_etypes
:
if
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
==
0
:
if
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
c_etype
))
==
0
:
self
.
_batch_num_edges
[
c_etype
]
=
F
.
zeros
(
self
.
_batch_num_edges
[
c_etype
]
=
F
.
zeros
(
(
self
.
batch_size
,),
F
.
int64
,
self
.
device
)
(
self
.
batch_size
,),
F
.
int64
,
self
.
device
)
continue
continue
one_hot_left_edges
=
F
.
zeros
((
old_num_edges
[
c_etype
],),
F
.
float32
,
self
.
device
)
one_hot_left_edges
=
F
.
zeros
(
(
old_num_edges
[
c_etype
],),
F
.
float32
,
self
.
device
)
eids
=
self
.
edges
[
c_etype
].
data
[
EID
]
eids
=
self
.
edges
[
c_etype
].
data
[
EID
]
one_hot_left_edges
=
F
.
scatter_row
(
one_hot_left_edges
,
eids
,
one_hot_left_edges
=
F
.
scatter_row
(
F
.
full_1d
(
len
(
eids
),
1.
,
F
.
float32
,
self
.
device
))
one_hot_left_edges
,
eids
,
F
.
full_1d
(
len
(
eids
),
1.0
,
F
.
float32
,
self
.
device
),
)
batch_num_left_edges
=
segment
.
segment_reduce
(
batch_num_left_edges
=
segment
.
segment_reduce
(
self
.
_batch_num_edges
[
c_etype
],
one_hot_left_edges
,
reducer
=
'sum'
)
self
.
_batch_num_edges
[
c_etype
],
self
.
_batch_num_edges
[
c_etype
]
=
F
.
astype
(
batch_num_left_edges
,
F
.
int64
)
one_hot_left_edges
,
reducer
=
"sum"
,
)
self
.
_batch_num_edges
[
c_etype
]
=
F
.
astype
(
batch_num_left_edges
,
F
.
int64
)
if
batched
and
not
store_ids
:
if
batched
and
not
store_ids
:
for
c_ntype
in
self
.
ntypes
:
for
c_ntype
in
self
.
ntypes
:
...
@@ -810,7 +956,6 @@ class DGLGraph(object):
...
@@ -810,7 +956,6 @@ class DGLGraph(object):
self
.
_batch_num_nodes
=
None
self
.
_batch_num_nodes
=
None
self
.
_batch_num_edges
=
None
self
.
_batch_num_edges
=
None
#################################################################
#################################################################
# Metagraph query
# Metagraph query
#################################################################
#################################################################
...
@@ -1080,7 +1225,9 @@ class DGLGraph(object):
...
@@ -1080,7 +1225,9 @@ class DGLGraph(object):
nx_graph
=
self
.
_graph
.
metagraph
.
to_networkx
()
nx_graph
=
self
.
_graph
.
metagraph
.
to_networkx
()
nx_metagraph
=
nx
.
MultiDiGraph
()
nx_metagraph
=
nx
.
MultiDiGraph
()
for
u_v
in
nx_graph
.
edges
:
for
u_v
in
nx_graph
.
edges
:
srctype
,
etype
,
dsttype
=
self
.
canonical_etypes
[
nx_graph
.
edges
[
u_v
][
'id'
]]
srctype
,
etype
,
dsttype
=
self
.
canonical_etypes
[
nx_graph
.
edges
[
u_v
][
"id"
]
]
nx_metagraph
.
add_edge
(
srctype
,
dsttype
,
etype
)
nx_metagraph
.
add_edge
(
srctype
,
dsttype
,
etype
)
return
nx_metagraph
return
nx_metagraph
...
@@ -1133,8 +1280,10 @@ class DGLGraph(object):
...
@@ -1133,8 +1280,10 @@ class DGLGraph(object):
"""
"""
if
etype
is
None
:
if
etype
is
None
:
if
len
(
self
.
etypes
)
!=
1
:
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
raise
DGLError
(
'edge types.'
)
"Edge type name must be specified if there are more than one "
"edge types."
)
etype
=
self
.
etypes
[
0
]
etype
=
self
.
etypes
[
0
]
if
isinstance
(
etype
,
tuple
):
if
isinstance
(
etype
,
tuple
):
return
etype
return
etype
...
@@ -1143,8 +1292,10 @@ class DGLGraph(object):
...
@@ -1143,8 +1292,10 @@ class DGLGraph(object):
if
ret
is
None
:
if
ret
is
None
:
raise
DGLError
(
'Edge type "{}" does not exist.'
.
format
(
etype
))
raise
DGLError
(
'Edge type "{}" does not exist.'
.
format
(
etype
))
if
len
(
ret
)
==
0
:
if
len
(
ret
)
==
0
:
raise
DGLError
(
'Edge type "%s" is ambiguous. Please use canonical edge type '
raise
DGLError
(
'in the form of (srctype, etype, dsttype)'
%
etype
)
'Edge type "%s" is ambiguous. Please use canonical edge type '
"in the form of (srctype, etype, dsttype)"
%
etype
)
return
ret
return
ret
def
get_ntype_id
(
self
,
ntype
):
def
get_ntype_id
(
self
,
ntype
):
...
@@ -1164,19 +1315,23 @@ class DGLGraph(object):
...
@@ -1164,19 +1315,23 @@ class DGLGraph(object):
"""
"""
if
self
.
is_unibipartite
and
ntype
is
not
None
:
if
self
.
is_unibipartite
and
ntype
is
not
None
:
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
# Only check 'SRC/' and 'DST/' prefix when is_unibipartite graph is True.
if
ntype
.
startswith
(
'
SRC/
'
):
if
ntype
.
startswith
(
"
SRC/
"
):
return
self
.
get_ntype_id_from_src
(
ntype
[
4
:])
return
self
.
get_ntype_id_from_src
(
ntype
[
4
:])
elif
ntype
.
startswith
(
'
DST/
'
):
elif
ntype
.
startswith
(
"
DST/
"
):
return
self
.
get_ntype_id_from_dst
(
ntype
[
4
:])
return
self
.
get_ntype_id_from_dst
(
ntype
[
4
:])
# If there is no prefix, fallback to normal lookup.
# If there is no prefix, fallback to normal lookup.
# Lookup both SRC and DST
# Lookup both SRC and DST
if
ntype
is
None
:
if
ntype
is
None
:
if
self
.
is_unibipartite
or
len
(
self
.
_srctypes_invmap
)
!=
1
:
if
self
.
is_unibipartite
or
len
(
self
.
_srctypes_invmap
)
!=
1
:
raise
DGLError
(
'Node type name must be specified if there are more than one '
raise
DGLError
(
'node types.'
)
"Node type name must be specified if there are more than one "
"node types."
)
return
0
return
0
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
))
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
)
if
ntid
is
None
:
if
ntid
is
None
:
raise
DGLError
(
'Node type "{}" does not exist.'
.
format
(
ntype
))
raise
DGLError
(
'Node type "{}" does not exist.'
.
format
(
ntype
))
return
ntid
return
ntid
...
@@ -1198,8 +1353,10 @@ class DGLGraph(object):
...
@@ -1198,8 +1353,10 @@ class DGLGraph(object):
"""
"""
if
ntype
is
None
:
if
ntype
is
None
:
if
len
(
self
.
_srctypes_invmap
)
!=
1
:
if
len
(
self
.
_srctypes_invmap
)
!=
1
:
raise
DGLError
(
'SRC node type name must be specified if there are more than one '
raise
DGLError
(
'SRC node types.'
)
"SRC node type name must be specified if there are more than one "
"SRC node types."
)
return
next
(
iter
(
self
.
_srctypes_invmap
.
values
()))
return
next
(
iter
(
self
.
_srctypes_invmap
.
values
()))
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
None
)
ntid
=
self
.
_srctypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
if
ntid
is
None
:
...
@@ -1223,8 +1380,10 @@ class DGLGraph(object):
...
@@ -1223,8 +1380,10 @@ class DGLGraph(object):
"""
"""
if
ntype
is
None
:
if
ntype
is
None
:
if
len
(
self
.
_dsttypes_invmap
)
!=
1
:
if
len
(
self
.
_dsttypes_invmap
)
!=
1
:
raise
DGLError
(
'DST node type name must be specified if there are more than one '
raise
DGLError
(
'DST node types.'
)
"DST node type name must be specified if there are more than one "
"DST node types."
)
return
next
(
iter
(
self
.
_dsttypes_invmap
.
values
()))
return
next
(
iter
(
self
.
_dsttypes_invmap
.
values
()))
ntid
=
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
ntid
=
self
.
_dsttypes_invmap
.
get
(
ntype
,
None
)
if
ntid
is
None
:
if
ntid
is
None
:
...
@@ -1248,8 +1407,10 @@ class DGLGraph(object):
...
@@ -1248,8 +1407,10 @@ class DGLGraph(object):
"""
"""
if
etype
is
None
:
if
etype
is
None
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
if
self
.
_graph
.
number_of_etypes
()
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
raise
DGLError
(
'edge types.'
)
"Edge type name must be specified if there are more than one "
"edge types."
)
return
0
return
0
etid
=
self
.
_etypes_invmap
.
get
(
self
.
to_canonical_etype
(
etype
),
None
)
etid
=
self
.
_etypes_invmap
.
get
(
self
.
to_canonical_etype
(
etype
),
None
)
if
etid
is
None
:
if
etid
is
None
:
...
@@ -1346,17 +1507,23 @@ class DGLGraph(object):
...
@@ -1346,17 +1507,23 @@ class DGLGraph(object):
tensor([2, 1])
tensor([2, 1])
"""
"""
if
ntype
is
not
None
and
ntype
not
in
self
.
ntypes
:
if
ntype
is
not
None
and
ntype
not
in
self
.
ntypes
:
raise
DGLError
(
'Expect ntype in {}, got {}'
.
format
(
self
.
ntypes
,
ntype
))
raise
DGLError
(
"Expect ntype in {}, got {}"
.
format
(
self
.
ntypes
,
ntype
)
)
if
self
.
_batch_num_nodes
is
None
:
if
self
.
_batch_num_nodes
is
None
:
self
.
_batch_num_nodes
=
{}
self
.
_batch_num_nodes
=
{}
for
ty
in
self
.
ntypes
:
for
ty
in
self
.
ntypes
:
bnn
=
F
.
copy_to
(
F
.
tensor
([
self
.
number_of_nodes
(
ty
)],
F
.
int64
),
self
.
device
)
bnn
=
F
.
copy_to
(
F
.
tensor
([
self
.
number_of_nodes
(
ty
)],
F
.
int64
),
self
.
device
)
self
.
_batch_num_nodes
[
ty
]
=
bnn
self
.
_batch_num_nodes
[
ty
]
=
bnn
if
ntype
is
None
:
if
ntype
is
None
:
if
len
(
self
.
ntypes
)
!=
1
:
if
len
(
self
.
ntypes
)
!=
1
:
raise
DGLError
(
'Node type name must be specified if there are more than one '
raise
DGLError
(
'node types.'
)
"Node type name must be specified if there are more than one "
"node types."
)
ntype
=
self
.
ntypes
[
0
]
ntype
=
self
.
ntypes
[
0
]
return
self
.
_batch_num_nodes
[
ntype
]
return
self
.
_batch_num_nodes
[
ntype
]
...
@@ -1440,8 +1607,10 @@ class DGLGraph(object):
...
@@ -1440,8 +1607,10 @@ class DGLGraph(object):
"""
"""
if
not
isinstance
(
val
,
Mapping
):
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
ntypes
)
!=
1
:
if
len
(
self
.
ntypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary when there are multiple node types.'
)
raise
DGLError
(
val
=
{
self
.
ntypes
[
0
]
:
val
}
"Must provide a dictionary when there are multiple node types."
)
val
=
{
self
.
ntypes
[
0
]:
val
}
self
.
_batch_num_nodes
=
val
self
.
_batch_num_nodes
=
val
def
batch_num_edges
(
self
,
etype
=
None
):
def
batch_num_edges
(
self
,
etype
=
None
):
...
@@ -1494,12 +1663,16 @@ class DGLGraph(object):
...
@@ -1494,12 +1663,16 @@ class DGLGraph(object):
if
self
.
_batch_num_edges
is
None
:
if
self
.
_batch_num_edges
is
None
:
self
.
_batch_num_edges
=
{}
self
.
_batch_num_edges
=
{}
for
ty
in
self
.
canonical_etypes
:
for
ty
in
self
.
canonical_etypes
:
bne
=
F
.
copy_to
(
F
.
tensor
([
self
.
number_of_edges
(
ty
)],
F
.
int64
),
self
.
device
)
bne
=
F
.
copy_to
(
F
.
tensor
([
self
.
number_of_edges
(
ty
)],
F
.
int64
),
self
.
device
)
self
.
_batch_num_edges
[
ty
]
=
bne
self
.
_batch_num_edges
[
ty
]
=
bne
if
etype
is
None
:
if
etype
is
None
:
if
len
(
self
.
etypes
)
!=
1
:
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
'Edge type name must be specified if there are more than one '
raise
DGLError
(
'edge types.'
)
"Edge type name must be specified if there are more than one "
"edge types."
)
etype
=
self
.
canonical_etypes
[
0
]
etype
=
self
.
canonical_etypes
[
0
]
else
:
else
:
etype
=
self
.
to_canonical_etype
(
etype
)
etype
=
self
.
to_canonical_etype
(
etype
)
...
@@ -1585,8 +1758,10 @@ class DGLGraph(object):
...
@@ -1585,8 +1758,10 @@ class DGLGraph(object):
"""
"""
if
not
isinstance
(
val
,
Mapping
):
if
not
isinstance
(
val
,
Mapping
):
if
len
(
self
.
etypes
)
!=
1
:
if
len
(
self
.
etypes
)
!=
1
:
raise
DGLError
(
'Must provide a dictionary when there are multiple edge types.'
)
raise
DGLError
(
val
=
{
self
.
canonical_etypes
[
0
]
:
val
}
"Must provide a dictionary when there are multiple edge types."
)
val
=
{
self
.
canonical_etypes
[
0
]:
val
}
self
.
_batch_num_edges
=
val
self
.
_batch_num_edges
=
val
#################################################################
#################################################################
...
@@ -2130,10 +2305,14 @@ class DGLGraph(object):
...
@@ -2130,10 +2305,14 @@ class DGLGraph(object):
def
_find_etypes
(
self
,
key
):
def
_find_etypes
(
self
,
key
):
etypes
=
[
etypes
=
[
i
for
i
,
(
srctype
,
etype
,
dsttype
)
in
enumerate
(
self
.
_canonical_etypes
)
if
i
(
key
[
0
]
==
SLICE_FULL
or
key
[
0
]
==
srctype
)
and
for
i
,
(
srctype
,
etype
,
dsttype
)
in
enumerate
(
(
key
[
1
]
==
SLICE_FULL
or
key
[
1
]
==
etype
)
and
self
.
_canonical_etypes
(
key
[
2
]
==
SLICE_FULL
or
key
[
2
]
==
dsttype
)]
)
if
(
key
[
0
]
==
SLICE_FULL
or
key
[
0
]
==
srctype
)
and
(
key
[
1
]
==
SLICE_FULL
or
key
[
1
]
==
etype
)
and
(
key
[
2
]
==
SLICE_FULL
or
key
[
2
]
==
dsttype
)
]
return
etypes
return
etypes
def
__getitem__
(
self
,
key
):
def
__getitem__
(
self
,
key
):
...
@@ -2215,9 +2394,11 @@ class DGLGraph(object):
...
@@ -2215,9 +2394,11 @@ class DGLGraph(object):
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
"""
"""
err_msg
=
"Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
+
\
err_msg
=
(
"to get view of one relation type. Use : to slice multiple types (e.g. "
+
\
"Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
"G['srctype', :, 'dsttype'])."
+
"to get view of one relation type. Use : to slice multiple types (e.g. "
+
"G['srctype', :, 'dsttype'])."
)
orig_key
=
key
orig_key
=
key
if
not
isinstance
(
key
,
tuple
):
if
not
isinstance
(
key
,
tuple
):
...
@@ -2229,7 +2410,11 @@ class DGLGraph(object):
...
@@ -2229,7 +2410,11 @@ class DGLGraph(object):
etypes
=
self
.
_find_etypes
(
key
)
etypes
=
self
.
_find_etypes
(
key
)
if
len
(
etypes
)
==
0
:
if
len
(
etypes
)
==
0
:
raise
DGLError
(
'Invalid key "{}". Must be one of the edge types.'
.
format
(
orig_key
))
raise
DGLError
(
'Invalid key "{}". Must be one of the edge types.'
.
format
(
orig_key
)
)
if
len
(
etypes
)
==
1
:
if
len
(
etypes
)
==
1
:
# no ambiguity: return the unitgraph itself
# no ambiguity: return the unitgraph itself
...
@@ -2248,7 +2433,9 @@ class DGLGraph(object):
...
@@ -2248,7 +2433,9 @@ class DGLGraph(object):
new_etypes
=
[
etype
]
new_etypes
=
[
etype
]
new_eframes
=
[
self
.
_edge_frames
[
etid
]]
new_eframes
=
[
self
.
_edge_frames
[
etid
]]
return
self
.
__class__
(
new_g
,
new_ntypes
,
new_etypes
,
new_nframes
,
new_eframes
)
return
self
.
__class__
(
new_g
,
new_ntypes
,
new_etypes
,
new_nframes
,
new_eframes
)
else
:
else
:
flat
=
self
.
_graph
.
flatten_relations
(
etypes
)
flat
=
self
.
_graph
.
flatten_relations
(
etypes
)
new_g
=
flat
.
graph
new_g
=
flat
.
graph
...
@@ -2262,7 +2449,8 @@ class DGLGraph(object):
...
@@ -2262,7 +2449,8 @@ class DGLGraph(object):
new_ntypes
.
append
(
combine_names
(
self
.
ntypes
,
dtids
))
new_ntypes
.
append
(
combine_names
(
self
.
ntypes
,
dtids
))
new_nframes
=
[
new_nframes
=
[
combine_frames
(
self
.
_node_frames
,
stids
),
combine_frames
(
self
.
_node_frames
,
stids
),
combine_frames
(
self
.
_node_frames
,
dtids
)]
combine_frames
(
self
.
_node_frames
,
dtids
),
]
else
:
else
:
assert
np
.
array_equal
(
stids
,
dtids
)
assert
np
.
array_equal
(
stids
,
dtids
)
new_nframes
=
[
combine_frames
(
self
.
_node_frames
,
stids
)]
new_nframes
=
[
combine_frames
(
self
.
_node_frames
,
stids
)]
...
@@ -2270,16 +2458,28 @@ class DGLGraph(object):
...
@@ -2270,16 +2458,28 @@ class DGLGraph(object):
new_eframes
=
[
combine_frames
(
self
.
_edge_frames
,
etids
)]
new_eframes
=
[
combine_frames
(
self
.
_edge_frames
,
etids
)]
# create new heterograph
# create new heterograph
new_hg
=
self
.
__class__
(
new_g
,
new_ntypes
,
new_etypes
,
new_nframes
,
new_eframes
)
new_hg
=
self
.
__class__
(
new_g
,
new_ntypes
,
new_etypes
,
new_nframes
,
new_eframes
)
src
=
new_ntypes
[
0
]
src
=
new_ntypes
[
0
]
dst
=
new_ntypes
[
1
]
if
new_g
.
number_of_ntypes
()
==
2
else
src
dst
=
new_ntypes
[
1
]
if
new_g
.
number_of_ntypes
()
==
2
else
src
# put the parent node/edge type and IDs
# put the parent node/edge type and IDs
new_hg
.
nodes
[
src
].
data
[
NTYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_srctype
)
new_hg
.
nodes
[
src
].
data
[
NTYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
new_hg
.
nodes
[
src
].
data
[
NID
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_srcid
)
flat
.
induced_srctype
new_hg
.
nodes
[
dst
].
data
[
NTYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_dsttype
)
)
new_hg
.
nodes
[
dst
].
data
[
NID
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_dstid
)
new_hg
.
nodes
[
src
].
data
[
NID
]
=
F
.
zerocopy_from_dgl_ndarray
(
new_hg
.
edata
[
ETYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_etype
)
flat
.
induced_srcid
)
new_hg
.
nodes
[
dst
].
data
[
NTYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_dsttype
)
new_hg
.
nodes
[
dst
].
data
[
NID
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_dstid
)
new_hg
.
edata
[
ETYPE
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_etype
)
new_hg
.
edata
[
EID
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_eid
)
new_hg
.
edata
[
EID
]
=
F
.
zerocopy_from_dgl_ndarray
(
flat
.
induced_eid
)
return
new_hg
return
new_hg
...
@@ -2331,7 +2531,12 @@ class DGLGraph(object):
...
@@ -2331,7 +2531,12 @@ class DGLGraph(object):
12
12
"""
"""
if
ntype
is
None
:
if
ntype
is
None
:
return
sum
([
self
.
_graph
.
number_of_nodes
(
ntid
)
for
ntid
in
range
(
len
(
self
.
ntypes
))])
return
sum
(
[
self
.
_graph
.
number_of_nodes
(
ntid
)
for
ntid
in
range
(
len
(
self
.
ntypes
))
]
)
else
:
else
:
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id
(
ntype
))
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id
(
ntype
))
...
@@ -2396,10 +2601,16 @@ class DGLGraph(object):
...
@@ -2396,10 +2601,16 @@ class DGLGraph(object):
7
7
"""
"""
if
ntype
is
None
:
if
ntype
is
None
:
return
sum
([
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_src
(
nty
))
return
sum
(
for
nty
in
self
.
srctypes
])
[
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_src
(
nty
))
for
nty
in
self
.
srctypes
]
)
else
:
else
:
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_src
(
ntype
))
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_src
(
ntype
)
)
def
number_of_dst_nodes
(
self
,
ntype
=
None
):
def
number_of_dst_nodes
(
self
,
ntype
=
None
):
"""Alias of :func:`num_dst_nodes`"""
"""Alias of :func:`num_dst_nodes`"""
...
@@ -2462,10 +2673,16 @@ class DGLGraph(object):
...
@@ -2462,10 +2673,16 @@ class DGLGraph(object):
12
12
"""
"""
if
ntype
is
None
:
if
ntype
is
None
:
return
sum
([
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_dst
(
nty
))
return
sum
(
for
nty
in
self
.
dsttypes
])
[
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_dst
(
nty
))
for
nty
in
self
.
dsttypes
]
)
else
:
else
:
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_dst
(
ntype
))
return
self
.
_graph
.
number_of_nodes
(
self
.
get_ntype_id_from_dst
(
ntype
)
)
def
number_of_edges
(
self
,
etype
=
None
):
def
number_of_edges
(
self
,
etype
=
None
):
"""Alias of :func:`num_edges`"""
"""Alias of :func:`num_edges`"""
...
@@ -2522,8 +2739,12 @@ class DGLGraph(object):
...
@@ -2522,8 +2739,12 @@ class DGLGraph(object):
3
3
"""
"""
if
etype
is
None
:
if
etype
is
None
:
return
sum
([
self
.
_graph
.
number_of_edges
(
etid
)
return
sum
(
for
etid
in
range
(
len
(
self
.
canonical_etypes
))])
[
self
.
_graph
.
number_of_edges
(
etid
)
for
etid
in
range
(
len
(
self
.
canonical_etypes
))
]
)
else
:
else
:
return
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
etype
))
return
self
.
_graph
.
number_of_edges
(
self
.
get_etype_id
(
etype
))
...
@@ -2708,10 +2929,11 @@ class DGLGraph(object):
...
@@ -2708,10 +2929,11 @@ class DGLGraph(object):
tensor([False, True, True])
tensor([False, True, True])
"""
"""
vid_tensor
=
utils
.
prepare_tensor
(
self
,
vid
,
"vid"
)
vid_tensor
=
utils
.
prepare_tensor
(
self
,
vid
,
"vid"
)
if
len
(
vid_tensor
)
>
0
and
F
.
as_scalar
(
F
.
min
(
vid_tensor
,
0
))
<
0
<
len
(
vid_tensor
):
if
len
(
vid_tensor
)
>
0
and
F
.
as_scalar
(
F
.
min
(
vid_tensor
,
0
))
<
0
<
len
(
raise
DGLError
(
'All IDs must be non-negative integers.'
)
vid_tensor
ret
=
self
.
_graph
.
has_nodes
(
):
self
.
get_ntype_id
(
ntype
),
vid_tensor
)
raise
DGLError
(
"All IDs must be non-negative integers."
)
ret
=
self
.
_graph
.
has_nodes
(
self
.
get_ntype_id
(
ntype
),
vid_tensor
)
if
isinstance
(
vid
,
numbers
.
Integral
):
if
isinstance
(
vid
,
numbers
.
Integral
):
return
bool
(
F
.
as_scalar
(
ret
))
return
bool
(
F
.
as_scalar
(
ret
))
else
:
else
:
...
@@ -2793,15 +3015,19 @@ class DGLGraph(object):
...
@@ -2793,15 +3015,19 @@ class DGLGraph(object):
tensor([True, True])
tensor([True, True])
"""
"""
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
u_tensor
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u_tensor
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u_tensor
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
u_tensor
):
if
F
.
as_scalar
(
raise
DGLError
(
'u contains invalid node IDs'
)
F
.
sum
(
self
.
has_nodes
(
u_tensor
,
ntype
=
srctype
),
dim
=
0
)
v_tensor
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
)
!=
len
(
u_tensor
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v_tensor
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v_tensor
):
raise
DGLError
(
"u contains invalid node IDs"
)
raise
DGLError
(
'v contains invalid node IDs'
)
v_tensor
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v_tensor
,
ntype
=
dsttype
),
dim
=
0
)
)
!=
len
(
v_tensor
):
raise
DGLError
(
"v contains invalid node IDs"
)
ret
=
self
.
_graph
.
has_edges_between
(
ret
=
self
.
_graph
.
has_edges_between
(
self
.
get_etype_id
(
etype
),
self
.
get_etype_id
(
etype
),
u_tensor
,
v_tensor
u_tensor
,
v_tensor
)
)
if
isinstance
(
u
,
numbers
.
Integral
)
and
isinstance
(
v
,
numbers
.
Integral
):
if
isinstance
(
u
,
numbers
.
Integral
)
and
isinstance
(
v
,
numbers
.
Integral
):
return
bool
(
F
.
as_scalar
(
ret
))
return
bool
(
F
.
as_scalar
(
ret
))
else
:
else
:
...
@@ -2863,7 +3089,7 @@ class DGLGraph(object):
...
@@ -2863,7 +3089,7 @@ class DGLGraph(object):
successors
successors
"""
"""
if
not
self
.
has_nodes
(
v
,
self
.
to_canonical_etype
(
etype
)[
-
1
]):
if
not
self
.
has_nodes
(
v
,
self
.
to_canonical_etype
(
etype
)[
-
1
]):
raise
DGLError
(
'
Non-existing node ID {}
'
.
format
(
v
))
raise
DGLError
(
"
Non-existing node ID {}
"
.
format
(
v
))
return
self
.
_graph
.
predecessors
(
self
.
get_etype_id
(
etype
),
v
)
return
self
.
_graph
.
predecessors
(
self
.
get_etype_id
(
etype
),
v
)
def
successors
(
self
,
v
,
etype
=
None
):
def
successors
(
self
,
v
,
etype
=
None
):
...
@@ -2921,7 +3147,7 @@ class DGLGraph(object):
...
@@ -2921,7 +3147,7 @@ class DGLGraph(object):
predecessors
predecessors
"""
"""
if
not
self
.
has_nodes
(
v
,
self
.
to_canonical_etype
(
etype
)[
0
]):
if
not
self
.
has_nodes
(
v
,
self
.
to_canonical_etype
(
etype
)[
0
]):
raise
DGLError
(
'
Non-existing node ID {}
'
.
format
(
v
))
raise
DGLError
(
"
Non-existing node ID {}
"
.
format
(
v
))
return
self
.
_graph
.
successors
(
self
.
get_etype_id
(
etype
),
v
)
return
self
.
_graph
.
successors
(
self
.
get_etype_id
(
etype
),
v
)
def
edge_ids
(
self
,
u
,
v
,
return_uv
=
False
,
etype
=
None
):
def
edge_ids
(
self
,
u
,
v
,
return_uv
=
False
,
etype
=
None
):
...
@@ -3018,14 +3244,20 @@ class DGLGraph(object):
...
@@ -3018,14 +3244,20 @@ class DGLGraph(object):
... etype=('user', 'follows', 'game'))
... etype=('user', 'follows', 'game'))
tensor([1, 2])
tensor([1, 2])
"""
"""
is_int
=
isinstance
(
u
,
numbers
.
Integral
)
and
isinstance
(
v
,
numbers
.
Integral
)
is_int
=
isinstance
(
u
,
numbers
.
Integral
)
and
isinstance
(
v
,
numbers
.
Integral
)
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
u
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
raise
DGLError
(
'u contains invalid node IDs'
)
u
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
"u contains invalid node IDs"
)
raise
DGLError
(
'v contains invalid node IDs'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
"v contains invalid node IDs"
)
if
return_uv
:
if
return_uv
:
return
self
.
_graph
.
edge_ids_all
(
self
.
get_etype_id
(
etype
),
u
,
v
)
return
self
.
_graph
.
edge_ids_all
(
self
.
get_etype_id
(
etype
),
u
,
v
)
...
@@ -3035,9 +3267,13 @@ class DGLGraph(object):
...
@@ -3035,9 +3267,13 @@ class DGLGraph(object):
if
F
.
as_scalar
(
F
.
sum
(
is_neg_one
,
0
)):
if
F
.
as_scalar
(
F
.
sum
(
is_neg_one
,
0
)):
# Raise error since some (u, v) pair is not a valid edge.
# Raise error since some (u, v) pair is not a valid edge.
idx
=
F
.
nonzero_1d
(
is_neg_one
)
idx
=
F
.
nonzero_1d
(
is_neg_one
)
raise
DGLError
(
"Error: (%d, %d) does not form a valid edge."
%
(
raise
DGLError
(
"Error: (%d, %d) does not form a valid edge."
%
(
F
.
as_scalar
(
F
.
gather_row
(
u
,
idx
)),
F
.
as_scalar
(
F
.
gather_row
(
u
,
idx
)),
F
.
as_scalar
(
F
.
gather_row
(
v
,
idx
))))
F
.
as_scalar
(
F
.
gather_row
(
v
,
idx
)),
)
)
return
F
.
as_scalar
(
eid
)
if
is_int
else
eid
return
F
.
as_scalar
(
eid
)
if
is_int
else
eid
def
find_edges
(
self
,
eid
,
etype
=
None
):
def
find_edges
(
self
,
eid
,
etype
=
None
):
...
@@ -3096,14 +3332,14 @@ class DGLGraph(object):
...
@@ -3096,14 +3332,14 @@ class DGLGraph(object):
>>> hg.find_edges(torch.tensor([1, 0]), 'plays')
>>> hg.find_edges(torch.tensor([1, 0]), 'plays')
(tensor([4, 3]), tensor([6, 5]))
(tensor([4, 3]), tensor([6, 5]))
"""
"""
eid
=
utils
.
prepare_tensor
(
self
,
eid
,
'
eid
'
)
eid
=
utils
.
prepare_tensor
(
self
,
eid
,
"
eid
"
)
if
len
(
eid
)
>
0
:
if
len
(
eid
)
>
0
:
min_eid
=
F
.
as_scalar
(
F
.
min
(
eid
,
0
))
min_eid
=
F
.
as_scalar
(
F
.
min
(
eid
,
0
))
if
min_eid
<
0
:
if
min_eid
<
0
:
raise
DGLError
(
'
Invalid edge ID {:d}
'
.
format
(
min_eid
))
raise
DGLError
(
"
Invalid edge ID {:d}
"
.
format
(
min_eid
))
max_eid
=
F
.
as_scalar
(
F
.
max
(
eid
,
0
))
max_eid
=
F
.
as_scalar
(
F
.
max
(
eid
,
0
))
if
max_eid
>=
self
.
num_edges
(
etype
):
if
max_eid
>=
self
.
num_edges
(
etype
):
raise
DGLError
(
'
Invalid edge ID {:d}
'
.
format
(
max_eid
))
raise
DGLError
(
"
Invalid edge ID {:d}
"
.
format
(
max_eid
))
if
len
(
eid
)
==
0
:
if
len
(
eid
)
==
0
:
empty
=
F
.
copy_to
(
F
.
tensor
([],
self
.
idtype
),
self
.
device
)
empty
=
F
.
copy_to
(
F
.
tensor
([],
self
.
idtype
),
self
.
device
)
...
@@ -3111,7 +3347,7 @@ class DGLGraph(object):
...
@@ -3111,7 +3347,7 @@ class DGLGraph(object):
src
,
dst
,
_
=
self
.
_graph
.
find_edges
(
self
.
get_etype_id
(
etype
),
eid
)
src
,
dst
,
_
=
self
.
_graph
.
find_edges
(
self
.
get_etype_id
(
etype
),
eid
)
return
src
,
dst
return
src
,
dst
def
in_edges
(
self
,
v
,
form
=
'
uv
'
,
etype
=
None
):
def
in_edges
(
self
,
v
,
form
=
"
uv
"
,
etype
=
None
):
"""Return the incoming edges of the given nodes.
"""Return the incoming edges of the given nodes.
Parameters
Parameters
...
@@ -3184,18 +3420,20 @@ class DGLGraph(object):
...
@@ -3184,18 +3420,20 @@ class DGLGraph(object):
edges
edges
out_edges
out_edges
"""
"""
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
src
,
dst
,
eid
=
self
.
_graph
.
in_edges
(
self
.
get_etype_id
(
etype
),
v
)
src
,
dst
,
eid
=
self
.
_graph
.
in_edges
(
self
.
get_etype_id
(
etype
),
v
)
if
form
==
'
all
'
:
if
form
==
"
all
"
:
return
src
,
dst
,
eid
return
src
,
dst
,
eid
elif
form
==
'
uv
'
:
elif
form
==
"
uv
"
:
return
src
,
dst
return
src
,
dst
elif
form
==
'
eid
'
:
elif
form
==
"
eid
"
:
return
eid
return
eid
else
:
else
:
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
)
)
def
out_edges
(
self
,
u
,
form
=
'
uv
'
,
etype
=
None
):
def
out_edges
(
self
,
u
,
form
=
"
uv
"
,
etype
=
None
):
"""Return the outgoing edges of the given nodes.
"""Return the outgoing edges of the given nodes.
Parameters
Parameters
...
@@ -3268,21 +3506,25 @@ class DGLGraph(object):
...
@@ -3268,21 +3506,25 @@ class DGLGraph(object):
edges
edges
in_edges
in_edges
"""
"""
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
srctype
,
_
,
_
=
self
.
to_canonical_etype
(
etype
)
srctype
,
_
,
_
=
self
.
to_canonical_etype
(
etype
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
u
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
raise
DGLError
(
'u contains invalid node IDs'
)
u
):
raise
DGLError
(
"u contains invalid node IDs"
)
src
,
dst
,
eid
=
self
.
_graph
.
out_edges
(
self
.
get_etype_id
(
etype
),
u
)
src
,
dst
,
eid
=
self
.
_graph
.
out_edges
(
self
.
get_etype_id
(
etype
),
u
)
if
form
==
'
all
'
:
if
form
==
"
all
"
:
return
src
,
dst
,
eid
return
src
,
dst
,
eid
elif
form
==
'
uv
'
:
elif
form
==
"
uv
"
:
return
src
,
dst
return
src
,
dst
elif
form
==
'
eid
'
:
elif
form
==
"
eid
"
:
return
eid
return
eid
else
:
else
:
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
)
)
def
all_edges
(
self
,
form
=
'
uv
'
,
order
=
'
eid
'
,
etype
=
None
):
def
all_edges
(
self
,
form
=
"
uv
"
,
order
=
"
eid
"
,
etype
=
None
):
"""Return all edges with the specified edge type.
"""Return all edges with the specified edge type.
Parameters
Parameters
...
@@ -3353,14 +3595,16 @@ class DGLGraph(object):
...
@@ -3353,14 +3595,16 @@ class DGLGraph(object):
out_edges
out_edges
"""
"""
src
,
dst
,
eid
=
self
.
_graph
.
edges
(
self
.
get_etype_id
(
etype
),
order
)
src
,
dst
,
eid
=
self
.
_graph
.
edges
(
self
.
get_etype_id
(
etype
),
order
)
if
form
==
'
all
'
:
if
form
==
"
all
"
:
return
src
,
dst
,
eid
return
src
,
dst
,
eid
elif
form
==
'
uv
'
:
elif
form
==
"
uv
"
:
return
src
,
dst
return
src
,
dst
elif
form
==
'
eid
'
:
elif
form
==
"
eid
"
:
return
eid
return
eid
else
:
else
:
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
))
raise
DGLError
(
'Invalid form: {}. Must be "all", "uv" or "eid".'
.
format
(
form
)
)
def
in_degrees
(
self
,
v
=
ALL
,
etype
=
None
):
def
in_degrees
(
self
,
v
=
ALL
,
etype
=
None
):
"""Return the in-degree(s) of the given nodes.
"""Return the in-degree(s) of the given nodes.
...
@@ -3431,7 +3675,7 @@ class DGLGraph(object):
...
@@ -3431,7 +3675,7 @@ class DGLGraph(object):
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
if
is_all
(
v
):
if
is_all
(
v
):
v
=
self
.
dstnodes
(
dsttype
)
v
=
self
.
dstnodes
(
dsttype
)
v_tensor
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v_tensor
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
deg
=
self
.
_graph
.
in_degrees
(
etid
,
v_tensor
)
deg
=
self
.
_graph
.
in_degrees
(
etid
,
v_tensor
)
if
isinstance
(
v
,
numbers
.
Integral
):
if
isinstance
(
v
,
numbers
.
Integral
):
return
F
.
as_scalar
(
deg
)
return
F
.
as_scalar
(
deg
)
...
@@ -3507,16 +3751,20 @@ class DGLGraph(object):
...
@@ -3507,16 +3751,20 @@ class DGLGraph(object):
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
if
is_all
(
u
):
if
is_all
(
u
):
u
=
self
.
srcnodes
(
srctype
)
u
=
self
.
srcnodes
(
srctype
)
u_tensor
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u_tensor
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u_tensor
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
u_tensor
):
if
F
.
as_scalar
(
raise
DGLError
(
'u contains invalid node IDs'
)
F
.
sum
(
self
.
has_nodes
(
u_tensor
,
ntype
=
srctype
),
dim
=
0
)
deg
=
self
.
_graph
.
out_degrees
(
etid
,
utils
.
prepare_tensor
(
self
,
u
,
'u'
))
)
!=
len
(
u_tensor
):
raise
DGLError
(
"u contains invalid node IDs"
)
deg
=
self
.
_graph
.
out_degrees
(
etid
,
utils
.
prepare_tensor
(
self
,
u
,
"u"
))
if
isinstance
(
u
,
numbers
.
Integral
):
if
isinstance
(
u
,
numbers
.
Integral
):
return
F
.
as_scalar
(
deg
)
return
F
.
as_scalar
(
deg
)
else
:
else
:
return
deg
return
deg
def
adjacency_matrix
(
self
,
transpose
=
False
,
ctx
=
F
.
cpu
(),
scipy_fmt
=
None
,
etype
=
None
):
def
adjacency_matrix
(
self
,
transpose
=
False
,
ctx
=
F
.
cpu
(),
scipy_fmt
=
None
,
etype
=
None
):
"""Alias of :meth:`adj`"""
"""Alias of :meth:`adj`"""
return
self
.
adj
(
transpose
,
ctx
,
scipy_fmt
,
etype
)
return
self
.
adj
(
transpose
,
ctx
,
scipy_fmt
,
etype
)
...
@@ -3586,7 +3834,9 @@ class DGLGraph(object):
...
@@ -3586,7 +3834,9 @@ class DGLGraph(object):
if
scipy_fmt
is
None
:
if
scipy_fmt
is
None
:
return
self
.
_graph
.
adjacency_matrix
(
etid
,
transpose
,
ctx
)[
0
]
return
self
.
_graph
.
adjacency_matrix
(
etid
,
transpose
,
ctx
)[
0
]
else
:
else
:
return
self
.
_graph
.
adjacency_matrix_scipy
(
etid
,
transpose
,
scipy_fmt
,
False
)
return
self
.
_graph
.
adjacency_matrix_scipy
(
etid
,
transpose
,
scipy_fmt
,
False
)
def
adj_sparse
(
self
,
fmt
,
etype
=
None
):
def
adj_sparse
(
self
,
fmt
,
etype
=
None
):
"""Return the adjacency matrix of edges of the given edge type as tensors of
"""Return the adjacency matrix of edges of the given edge type as tensors of
...
@@ -3629,9 +3879,9 @@ class DGLGraph(object):
...
@@ -3629,9 +3879,9 @@ class DGLGraph(object):
(tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))
(tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))
"""
"""
etid
=
self
.
get_etype_id
(
etype
)
etid
=
self
.
get_etype_id
(
etype
)
if
fmt
==
'
csc
'
:
if
fmt
==
"
csc
"
:
# The first two elements are number of rows and columns
# The first two elements are number of rows and columns
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
True
,
'
csr
'
)[
2
:]
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
True
,
"
csr
"
)[
2
:]
else
:
else
:
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
False
,
fmt
)[
2
:]
return
self
.
_graph
.
adjacency_matrix_tensors
(
etid
,
False
,
fmt
)[
2
:]
...
@@ -4024,26 +4274,36 @@ class DGLGraph(object):
...
@@ -4024,26 +4274,36 @@ class DGLGraph(object):
if
is_all
(
u
):
if
is_all
(
u
):
num_nodes
=
self
.
_graph
.
number_of_nodes
(
ntid
)
num_nodes
=
self
.
_graph
.
number_of_nodes
(
ntid
)
else
:
else
:
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
num_nodes
=
len
(
u
)
num_nodes
=
len
(
u
)
for
key
,
val
in
data
.
items
():
for
key
,
val
in
data
.
items
():
nfeats
=
F
.
shape
(
val
)[
0
]
nfeats
=
F
.
shape
(
val
)[
0
]
if
nfeats
!=
num_nodes
:
if
nfeats
!=
num_nodes
:
raise
DGLError
(
'Expect number of features to match number of nodes (len(u)).'
raise
DGLError
(
' Got %d and %d instead.'
%
(
nfeats
,
num_nodes
))
"Expect number of features to match number of nodes (len(u))."
" Got %d and %d instead."
%
(
nfeats
,
num_nodes
)
)
if
F
.
context
(
val
)
!=
self
.
device
:
if
F
.
context
(
val
)
!=
self
.
device
:
raise
DGLError
(
'Cannot assign node feature "{}" on device {} to a graph on'
raise
DGLError
(
' device {}. Call DGLGraph.to() to copy the graph to the'
'Cannot assign node feature "{}" on device {} to a graph on'
' same device.'
.
format
(
key
,
F
.
context
(
val
),
self
.
device
))
" device {}. Call DGLGraph.to() to copy the graph to the"
" same device."
.
format
(
key
,
F
.
context
(
val
),
self
.
device
)
)
# To prevent users from doing things like:
# To prevent users from doing things like:
#
#
# g.pin_memory_()
# g.pin_memory_()
# g.ndata['x'] = torch.randn(...)
# g.ndata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.ndata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
# sg.ndata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if
self
.
is_pinned
()
and
F
.
context
(
val
)
==
'cpu'
and
not
F
.
is_pinned
(
val
):
if
(
raise
DGLError
(
'Pinned graph requires the node data to be pinned as well. '
self
.
is_pinned
()
'Please pin the node data before assignment.'
)
and
F
.
context
(
val
)
==
"cpu"
and
not
F
.
is_pinned
(
val
)
):
raise
DGLError
(
"Pinned graph requires the node data to be pinned as well. "
"Please pin the node data before assignment."
)
if
is_all
(
u
):
if
is_all
(
u
):
self
.
_node_frames
[
ntid
].
update
(
data
)
self
.
_node_frames
[
ntid
].
update
(
data
)
...
@@ -4070,7 +4330,7 @@ class DGLGraph(object):
...
@@ -4070,7 +4330,7 @@ class DGLGraph(object):
if
is_all
(
u
):
if
is_all
(
u
):
return
self
.
_node_frames
[
ntid
]
return
self
.
_node_frames
[
ntid
]
else
:
else
:
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
return
self
.
_node_frames
[
ntid
].
subframe
(
u
)
return
self
.
_node_frames
[
ntid
].
subframe
(
u
)
def
_pop_n_repr
(
self
,
ntid
,
key
):
def
_pop_n_repr
(
self
,
ntid
,
key
):
...
@@ -4116,12 +4376,14 @@ class DGLGraph(object):
...
@@ -4116,12 +4376,14 @@ class DGLGraph(object):
"""
"""
# parse argument
# parse argument
if
not
is_all
(
edges
):
if
not
is_all
(
edges
):
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
'
edges
'
)
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
"
edges
"
)
# sanity check
# sanity check
if
not
utils
.
is_dict_like
(
data
):
if
not
utils
.
is_dict_like
(
data
):
raise
DGLError
(
'Expect dictionary type for feature data.'
raise
DGLError
(
' Got "%s" instead.'
%
type
(
data
))
"Expect dictionary type for feature data."
' Got "%s" instead.'
%
type
(
data
)
)
if
is_all
(
edges
):
if
is_all
(
edges
):
num_edges
=
self
.
_graph
.
number_of_edges
(
etid
)
num_edges
=
self
.
_graph
.
number_of_edges
(
etid
)
...
@@ -4130,21 +4392,31 @@ class DGLGraph(object):
...
@@ -4130,21 +4392,31 @@ class DGLGraph(object):
for
key
,
val
in
data
.
items
():
for
key
,
val
in
data
.
items
():
nfeats
=
F
.
shape
(
val
)[
0
]
nfeats
=
F
.
shape
(
val
)[
0
]
if
nfeats
!=
num_edges
:
if
nfeats
!=
num_edges
:
raise
DGLError
(
'Expect number of features to match number of edges.'
raise
DGLError
(
' Got %d and %d instead.'
%
(
nfeats
,
num_edges
))
"Expect number of features to match number of edges."
" Got %d and %d instead."
%
(
nfeats
,
num_edges
)
)
if
F
.
context
(
val
)
!=
self
.
device
:
if
F
.
context
(
val
)
!=
self
.
device
:
raise
DGLError
(
'Cannot assign edge feature "{}" on device {} to a graph on'
raise
DGLError
(
' device {}. Call DGLGraph.to() to copy the graph to the'
'Cannot assign edge feature "{}" on device {} to a graph on'
' same device.'
.
format
(
key
,
F
.
context
(
val
),
self
.
device
))
" device {}. Call DGLGraph.to() to copy the graph to the"
" same device."
.
format
(
key
,
F
.
context
(
val
),
self
.
device
)
)
# To prevent users from doing things like:
# To prevent users from doing things like:
#
#
# g.pin_memory_()
# g.pin_memory_()
# g.edata['x'] = torch.randn(...)
# g.edata['x'] = torch.randn(...)
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg = g.sample_neighbors(torch.LongTensor([...]).cuda())
# sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
# sg.edata['x'] # Becomes a CPU tensor even if sg is on GPU due to lazy slicing
if
self
.
is_pinned
()
and
F
.
context
(
val
)
==
'cpu'
and
not
F
.
is_pinned
(
val
):
if
(
raise
DGLError
(
'Pinned graph requires the edge data to be pinned as well. '
self
.
is_pinned
()
'Please pin the edge data before assignment.'
)
and
F
.
context
(
val
)
==
"cpu"
and
not
F
.
is_pinned
(
val
)
):
raise
DGLError
(
"Pinned graph requires the edge data to be pinned as well. "
"Please pin the edge data before assignment."
)
# set
# set
if
is_all
(
edges
):
if
is_all
(
edges
):
...
@@ -4172,7 +4444,7 @@ class DGLGraph(object):
...
@@ -4172,7 +4444,7 @@ class DGLGraph(object):
if
is_all
(
edges
):
if
is_all
(
edges
):
return
self
.
_edge_frames
[
etid
]
return
self
.
_edge_frames
[
etid
]
else
:
else
:
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
'
edges
'
)
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
"
edges
"
)
return
self
.
_edge_frames
[
etid
].
subframe
(
eid
)
return
self
.
_edge_frames
[
etid
].
subframe
(
eid
)
def
_pop_e_repr
(
self
,
etid
,
key
):
def
_pop_e_repr
(
self
,
etid
,
key
):
...
@@ -4256,7 +4528,7 @@ class DGLGraph(object):
...
@@ -4256,7 +4528,7 @@ class DGLGraph(object):
if
is_all
(
v
):
if
is_all
(
v
):
v_id
=
self
.
nodes
(
ntype
)
v_id
=
self
.
nodes
(
ntype
)
else
:
else
:
v_id
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v_id
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
ndata
=
core
.
invoke_node_udf
(
self
,
v_id
,
ntype
,
func
,
orig_nid
=
v_id
)
ndata
=
core
.
invoke_node_udf
(
self
,
v_id
,
ntype
,
func
,
orig_nid
=
v_id
)
self
.
_set_n_repr
(
ntid
,
v
,
ndata
)
self
.
_set_n_repr
(
ntid
,
v
,
ndata
)
...
@@ -4348,14 +4620,16 @@ class DGLGraph(object):
...
@@ -4348,14 +4620,16 @@ class DGLGraph(object):
g
=
self
if
etype
is
None
else
self
[
etype
]
g
=
self
if
etype
is
None
else
self
[
etype
]
else
:
# heterogeneous graph with number of relation types > 1
else
:
# heterogeneous graph with number of relation types > 1
if
not
core
.
is_builtin
(
func
):
if
not
core
.
is_builtin
(
func
):
raise
DGLError
(
"User defined functions are not yet "
raise
DGLError
(
"User defined functions are not yet "
"supported in apply_edges for heterogeneous graphs. "
"supported in apply_edges for heterogeneous graphs. "
"Please use (apply_edges(func), etype = rel) instead."
)
"Please use (apply_edges(func), etype = rel) instead."
)
g
=
self
g
=
self
if
is_all
(
edges
):
if
is_all
(
edges
):
eid
=
ALL
eid
=
ALL
else
:
else
:
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
'
edges
'
)
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
"
edges
"
)
if
core
.
is_builtin
(
func
):
if
core
.
is_builtin
(
func
):
if
not
is_all
(
eid
):
if
not
is_all
(
eid
):
g
=
g
.
edge_subgraph
(
eid
,
relabel_nodes
=
False
)
g
=
g
.
edge_subgraph
(
eid
,
relabel_nodes
=
False
)
...
@@ -4375,12 +4649,9 @@ class DGLGraph(object):
...
@@ -4375,12 +4649,9 @@ class DGLGraph(object):
edata_tensor
[
key
]
=
out_tensor_tuples
[
etid
]
edata_tensor
[
key
]
=
out_tensor_tuples
[
etid
]
self
.
_set_e_repr
(
etid
,
eid
,
edata_tensor
)
self
.
_set_e_repr
(
etid
,
eid
,
edata_tensor
)
def
send_and_recv
(
self
,
def
send_and_recv
(
edges
,
self
,
edges
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
message_func
,
):
reduce_func
,
apply_node_func
=
None
,
etype
=
None
):
"""Send messages along the specified edges and reduce them on
"""Send messages along the specified edges and reduce them on
the destination nodes to update their features.
the destination nodes to update their features.
...
@@ -4493,7 +4764,7 @@ class DGLGraph(object):
...
@@ -4493,7 +4764,7 @@ class DGLGraph(object):
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
etype
=
self
.
canonical_etypes
[
etid
]
etype
=
self
.
canonical_etypes
[
etid
]
# edge IDs
# edge IDs
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
'
edges
'
)
eid
=
utils
.
parse_edges_arg_to_eid
(
self
,
edges
,
etid
,
"
edges
"
)
if
len
(
eid
)
==
0
:
if
len
(
eid
)
==
0
:
# no computation
# no computation
return
return
...
@@ -4502,15 +4773,13 @@ class DGLGraph(object):
...
@@ -4502,15 +4773,13 @@ class DGLGraph(object):
g
=
self
if
etype
is
None
else
self
[
etype
]
g
=
self
if
etype
is
None
else
self
[
etype
]
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
u
,
v
,
eid
)
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
u
,
v
,
eid
)
ndata
=
core
.
message_passing
(
ndata
=
core
.
message_passing
(
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
def
pull
(
self
,
def
pull
(
v
,
self
,
v
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
message_func
,
):
reduce_func
,
apply_node_func
=
None
,
etype
=
None
):
"""Pull messages from the specified node(s)' predecessors along the
"""Pull messages from the specified node(s)' predecessors along the
specified edge type, aggregate them to update the node features.
specified edge type, aggregate them to update the node features.
...
@@ -4588,7 +4857,7 @@ class DGLGraph(object):
...
@@ -4588,7 +4857,7 @@ class DGLGraph(object):
[1.],
[1.],
[1.]])
[1.]])
"""
"""
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
if
len
(
v
)
==
0
:
if
len
(
v
)
==
0
:
# no computation
# no computation
return
return
...
@@ -4597,18 +4866,18 @@ class DGLGraph(object):
...
@@ -4597,18 +4866,18 @@ class DGLGraph(object):
etype
=
self
.
canonical_etypes
[
etid
]
etype
=
self
.
canonical_etypes
[
etid
]
g
=
self
if
etype
is
None
else
self
[
etype
]
g
=
self
if
etype
is
None
else
self
[
etype
]
# call message passing on subgraph
# call message passing on subgraph
src
,
dst
,
eid
=
g
.
in_edges
(
v
,
form
=
'all'
)
src
,
dst
,
eid
=
g
.
in_edges
(
v
,
form
=
"all"
)
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
src
,
dst
,
eid
,
v
)
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
src
,
dst
,
eid
,
v
)
ndata
=
core
.
message_passing
(
ndata
=
core
.
message_passing
(
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
def
push
(
self
,
def
push
(
u
,
self
,
u
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
message_func
,
):
reduce_func
,
apply_node_func
=
None
,
etype
=
None
):
"""Send message from the specified node(s) to their successors
"""Send message from the specified node(s) to their successors
along the specified edge type and update their node features.
along the specified edge type and update their node features.
...
@@ -4679,14 +4948,14 @@ class DGLGraph(object):
...
@@ -4679,14 +4948,14 @@ class DGLGraph(object):
[0.],
[0.],
[0.]])
[0.]])
"""
"""
edges
=
self
.
out_edges
(
u
,
form
=
'
eid
'
,
etype
=
etype
)
edges
=
self
.
out_edges
(
u
,
form
=
"
eid
"
,
etype
=
etype
)
self
.
send_and_recv
(
edges
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
)
self
.
send_and_recv
(
edges
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
def
update_all
(
self
,
)
message_func
,
reduce_func
,
def
update_all
(
apply_node_func
=
None
,
self
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
etype
=
None
etype
=
None
):
):
"""Send messages along all the edges of the specified type
"""Send messages along all the edges of the specified type
and update all the nodes of the corresponding destination type.
and update all the nodes of the corresponding destination type.
...
@@ -4778,23 +5047,37 @@ class DGLGraph(object):
...
@@ -4778,23 +5047,37 @@ class DGLGraph(object):
etype
=
self
.
canonical_etypes
[
etid
]
etype
=
self
.
canonical_etypes
[
etid
]
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
g
=
self
if
etype
is
None
else
self
[
etype
]
g
=
self
if
etype
is
None
else
self
[
etype
]
ndata
=
core
.
message_passing
(
g
,
message_func
,
reduce_func
,
apply_node_func
)
ndata
=
core
.
message_passing
(
if
core
.
is_builtin
(
reduce_func
)
and
reduce_func
.
name
in
[
'min'
,
'max'
]
and
ndata
:
g
,
message_func
,
reduce_func
,
apply_node_func
)
if
(
core
.
is_builtin
(
reduce_func
)
and
reduce_func
.
name
in
[
"min"
,
"max"
]
and
ndata
):
# Replace infinity with zero for isolated nodes
# Replace infinity with zero for isolated nodes
key
=
list
(
ndata
.
keys
())[
0
]
key
=
list
(
ndata
.
keys
())[
0
]
ndata
[
key
]
=
F
.
replace_inf_with_zero
(
ndata
[
key
])
ndata
[
key
]
=
F
.
replace_inf_with_zero
(
ndata
[
key
])
self
.
_set_n_repr
(
dtid
,
ALL
,
ndata
)
self
.
_set_n_repr
(
dtid
,
ALL
,
ndata
)
else
:
# heterogeneous graph with number of relation types > 1
else
:
# heterogeneous graph with number of relation types > 1
if
not
core
.
is_builtin
(
message_func
)
or
not
core
.
is_builtin
(
reduce_func
):
if
not
core
.
is_builtin
(
message_func
)
or
not
core
.
is_builtin
(
raise
DGLError
(
"User defined functions are not yet "
reduce_func
):
raise
DGLError
(
"User defined functions are not yet "
"supported in update_all for heterogeneous graphs. "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead."
)
"Please use multi_update_all instead."
if
reduce_func
.
name
in
[
'mean'
]:
)
raise
NotImplementedError
(
"Cannot set both intra-type and inter-type reduce "
if
reduce_func
.
name
in
[
"mean"
]:
raise
NotImplementedError
(
"Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
"operators as 'mean' using update_all. Please use "
"multi_update_all instead."
)
"multi_update_all instead."
)
g
=
self
g
=
self
all_out
=
core
.
message_passing
(
g
,
message_func
,
reduce_func
,
apply_node_func
)
all_out
=
core
.
message_passing
(
g
,
message_func
,
reduce_func
,
apply_node_func
)
key
=
list
(
all_out
.
keys
())[
0
]
key
=
list
(
all_out
.
keys
())[
0
]
out_tensor_tuples
=
all_out
[
key
]
out_tensor_tuples
=
all_out
[
key
]
...
@@ -4802,7 +5085,10 @@ class DGLGraph(object):
...
@@ -4802,7 +5085,10 @@ class DGLGraph(object):
for
_
,
_
,
dsttype
in
g
.
canonical_etypes
:
for
_
,
_
,
dsttype
in
g
.
canonical_etypes
:
dtid
=
g
.
get_ntype_id
(
dsttype
)
dtid
=
g
.
get_ntype_id
(
dsttype
)
dst_tensor
[
key
]
=
out_tensor_tuples
[
dtid
]
dst_tensor
[
key
]
=
out_tensor_tuples
[
dtid
]
if
core
.
is_builtin
(
reduce_func
)
and
reduce_func
.
name
in
[
'min'
,
'max'
]:
if
core
.
is_builtin
(
reduce_func
)
and
reduce_func
.
name
in
[
"min"
,
"max"
,
]:
dst_tensor
[
key
]
=
F
.
replace_inf_with_zero
(
dst_tensor
[
key
])
dst_tensor
[
key
]
=
F
.
replace_inf_with_zero
(
dst_tensor
[
key
])
self
.
_node_frames
[
dtid
].
update
(
dst_tensor
)
self
.
_node_frames
[
dtid
].
update
(
dst_tensor
)
...
@@ -4902,36 +5188,44 @@ class DGLGraph(object):
...
@@ -4902,36 +5188,44 @@ class DGLGraph(object):
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
_
,
dtid
=
self
.
_graph
.
metagraph
.
find_edge
(
etid
)
args
=
pad_tuple
(
args
,
3
)
args
=
pad_tuple
(
args
,
3
)
if
args
is
None
:
if
args
is
None
:
raise
DGLError
(
'Invalid arguments for edge type "{}". Should be '
raise
DGLError
(
'(msg_func, reduce_func, [apply_node_func])'
.
format
(
etype
))
'Invalid arguments for edge type "{}". Should be '
"(msg_func, reduce_func, [apply_node_func])"
.
format
(
etype
)
)
mfunc
,
rfunc
,
afunc
=
args
mfunc
,
rfunc
,
afunc
=
args
g
=
self
if
etype
is
None
else
self
[
etype
]
g
=
self
if
etype
is
None
else
self
[
etype
]
all_out
[
dtid
].
append
(
core
.
message_passing
(
g
,
mfunc
,
rfunc
,
afunc
))
all_out
[
dtid
].
append
(
core
.
message_passing
(
g
,
mfunc
,
rfunc
,
afunc
))
merge_order
[
dtid
].
append
(
etid
)
# use edge type id as merge order hint
merge_order
[
dtid
].
append
(
etid
)
# use edge type id as merge order hint
for
dtid
,
frames
in
all_out
.
items
():
for
dtid
,
frames
in
all_out
.
items
():
# merge by cross_reducer
# merge by cross_reducer
out
=
reduce_dict_data
(
frames
,
cross_reducer
,
merge_order
[
dtid
])
out
=
reduce_dict_data
(
frames
,
cross_reducer
,
merge_order
[
dtid
])
# Replace infinity with zero for isolated nodes when reducer is min/max
# Replace infinity with zero for isolated nodes when reducer is min/max
if
core
.
is_builtin
(
rfunc
)
and
rfunc
.
name
in
[
'
min
'
,
'
max
'
]:
if
core
.
is_builtin
(
rfunc
)
and
rfunc
.
name
in
[
"
min
"
,
"
max
"
]:
key
=
list
(
out
.
keys
())[
0
]
key
=
list
(
out
.
keys
())[
0
]
out
[
key
]
=
F
.
replace_inf_with_zero
(
out
[
key
])
if
out
[
key
]
is
not
None
else
None
out
[
key
]
=
(
F
.
replace_inf_with_zero
(
out
[
key
])
if
out
[
key
]
is
not
None
else
None
)
self
.
_node_frames
[
dtid
].
update
(
out
)
self
.
_node_frames
[
dtid
].
update
(
out
)
# apply
# apply
if
apply_node_func
is
not
None
:
if
apply_node_func
is
not
None
:
self
.
apply_nodes
(
apply_node_func
,
ALL
,
self
.
ntypes
[
dtid
])
self
.
apply_nodes
(
apply_node_func
,
ALL
,
self
.
ntypes
[
dtid
])
#################################################################
#################################################################
# Message propagation
# Message propagation
#################################################################
#################################################################
def
prop_nodes
(
self
,
def
prop_nodes
(
self
,
nodes_generator
,
nodes_generator
,
message_func
,
message_func
,
reduce_func
,
reduce_func
,
apply_node_func
=
None
,
apply_node_func
=
None
,
etype
=
None
):
etype
=
None
,
):
"""Propagate messages using graph traversal by sequentially triggering
"""Propagate messages using graph traversal by sequentially triggering
:func:`pull()` on nodes.
:func:`pull()` on nodes.
...
@@ -4987,14 +5281,22 @@ class DGLGraph(object):
...
@@ -4987,14 +5281,22 @@ class DGLGraph(object):
prop_edges
prop_edges
"""
"""
for
node_frontier
in
nodes_generator
:
for
node_frontier
in
nodes_generator
:
self
.
pull
(
node_frontier
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
)
self
.
pull
(
node_frontier
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
,
)
def
prop_edges
(
self
,
def
prop_edges
(
self
,
edges_generator
,
edges_generator
,
message_func
,
message_func
,
reduce_func
,
reduce_func
,
apply_node_func
=
None
,
apply_node_func
=
None
,
etype
=
None
):
etype
=
None
,
):
"""Propagate messages using graph traversal by sequentially triggering
"""Propagate messages using graph traversal by sequentially triggering
:func:`send_and_recv()` on edges.
:func:`send_and_recv()` on edges.
...
@@ -5051,8 +5353,13 @@ class DGLGraph(object):
...
@@ -5051,8 +5353,13 @@ class DGLGraph(object):
prop_nodes
prop_nodes
"""
"""
for
edge_frontier
in
edges_generator
:
for
edge_frontier
in
edges_generator
:
self
.
send_and_recv
(
edge_frontier
,
message_func
,
reduce_func
,
self
.
send_and_recv
(
apply_node_func
,
etype
=
etype
)
edge_frontier
,
message_func
,
reduce_func
,
apply_node_func
,
etype
=
etype
,
)
#################################################################
#################################################################
# Misc
# Misc
...
@@ -5127,14 +5434,16 @@ class DGLGraph(object):
...
@@ -5127,14 +5434,16 @@ class DGLGraph(object):
"""
"""
if
is_all
(
nodes
):
if
is_all
(
nodes
):
nodes
=
self
.
nodes
(
ntype
)
nodes
=
self
.
nodes
(
ntype
)
v
=
utils
.
prepare_tensor
(
self
,
nodes
,
'
nodes
'
)
v
=
utils
.
prepare_tensor
(
self
,
nodes
,
"
nodes
"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
ntype
),
dim
=
0
))
!=
len
(
v
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
ntype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
'
v contains invalid node IDs
'
)
raise
DGLError
(
"
v contains invalid node IDs
"
)
with
self
.
local_scope
():
with
self
.
local_scope
():
self
.
apply_nodes
(
lambda
nbatch
:
{
'_mask'
:
predicate
(
nbatch
)},
nodes
,
ntype
)
self
.
apply_nodes
(
lambda
nbatch
:
{
"_mask"
:
predicate
(
nbatch
)},
nodes
,
ntype
)
ntype
=
self
.
ntypes
[
0
]
if
ntype
is
None
else
ntype
ntype
=
self
.
ntypes
[
0
]
if
ntype
is
None
else
ntype
mask
=
self
.
nodes
[
ntype
].
data
[
'
_mask
'
]
mask
=
self
.
nodes
[
ntype
].
data
[
"
_mask
"
]
if
is_all
(
nodes
):
if
is_all
(
nodes
):
return
F
.
nonzero_1d
(
mask
)
return
F
.
nonzero_1d
(
mask
)
else
:
else
:
...
@@ -5221,34 +5530,40 @@ class DGLGraph(object):
...
@@ -5221,34 +5530,40 @@ class DGLGraph(object):
elif
isinstance
(
edges
,
tuple
):
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
,
v
=
edges
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
srctype
,
_
,
dsttype
=
self
.
to_canonical_etype
(
etype
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
'u'
)
u
=
utils
.
prepare_tensor
(
self
,
u
,
"u"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
))
!=
len
(
u
):
if
F
.
as_scalar
(
raise
DGLError
(
'edges[0] contains invalid node IDs'
)
F
.
sum
(
self
.
has_nodes
(
u
,
ntype
=
srctype
),
dim
=
0
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
'v'
)
)
!=
len
(
u
):
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
))
!=
len
(
v
):
raise
DGLError
(
"edges[0] contains invalid node IDs"
)
raise
DGLError
(
'edges[1] contains invalid node IDs'
)
v
=
utils
.
prepare_tensor
(
self
,
v
,
"v"
)
if
F
.
as_scalar
(
F
.
sum
(
self
.
has_nodes
(
v
,
ntype
=
dsttype
),
dim
=
0
)
)
!=
len
(
v
):
raise
DGLError
(
"edges[1] contains invalid node IDs"
)
elif
isinstance
(
edges
,
Iterable
)
or
F
.
is_tensor
(
edges
):
elif
isinstance
(
edges
,
Iterable
)
or
F
.
is_tensor
(
edges
):
edges
=
utils
.
prepare_tensor
(
self
,
edges
,
'
edges
'
)
edges
=
utils
.
prepare_tensor
(
self
,
edges
,
"
edges
"
)
min_eid
=
F
.
as_scalar
(
F
.
min
(
edges
,
0
))
min_eid
=
F
.
as_scalar
(
F
.
min
(
edges
,
0
))
if
len
(
edges
)
>
0
>
min_eid
:
if
len
(
edges
)
>
0
>
min_eid
:
raise
DGLError
(
'
Invalid edge ID {:d}
'
.
format
(
min_eid
))
raise
DGLError
(
"
Invalid edge ID {:d}
"
.
format
(
min_eid
))
max_eid
=
F
.
as_scalar
(
F
.
max
(
edges
,
0
))
max_eid
=
F
.
as_scalar
(
F
.
max
(
edges
,
0
))
if
len
(
edges
)
>
0
and
max_eid
>=
self
.
num_edges
(
etype
):
if
len
(
edges
)
>
0
and
max_eid
>=
self
.
num_edges
(
etype
):
raise
DGLError
(
'
Invalid edge ID {:d}
'
.
format
(
max_eid
))
raise
DGLError
(
"
Invalid edge ID {:d}
"
.
format
(
max_eid
))
else
:
else
:
raise
ValueError
(
'
Unsupported type of edges:
'
,
type
(
edges
))
raise
ValueError
(
"
Unsupported type of edges:
"
,
type
(
edges
))
with
self
.
local_scope
():
with
self
.
local_scope
():
self
.
apply_edges
(
lambda
ebatch
:
{
'_mask'
:
predicate
(
ebatch
)},
edges
,
etype
)
self
.
apply_edges
(
lambda
ebatch
:
{
"_mask"
:
predicate
(
ebatch
)},
edges
,
etype
)
etype
=
self
.
canonical_etypes
[
0
]
if
etype
is
None
else
etype
etype
=
self
.
canonical_etypes
[
0
]
if
etype
is
None
else
etype
mask
=
self
.
edges
[
etype
].
data
[
'
_mask
'
]
mask
=
self
.
edges
[
etype
].
data
[
"
_mask
"
]
if
is_all
(
edges
):
if
is_all
(
edges
):
return
F
.
nonzero_1d
(
mask
)
return
F
.
nonzero_1d
(
mask
)
else
:
else
:
if
isinstance
(
edges
,
tuple
):
if
isinstance
(
edges
,
tuple
):
e
=
self
.
edge_ids
(
edges
[
0
],
edges
[
1
],
etype
=
etype
)
e
=
self
.
edge_ids
(
edges
[
0
],
edges
[
1
],
etype
=
etype
)
else
:
else
:
e
=
utils
.
prepare_tensor
(
self
,
edges
,
'
edges
'
)
e
=
utils
.
prepare_tensor
(
self
,
edges
,
"
edges
"
)
return
F
.
boolean_mask
(
e
,
F
.
gather_row
(
mask
,
e
))
return
F
.
boolean_mask
(
e
,
F
.
gather_row
(
mask
,
e
))
@
property
@
property
...
@@ -5347,12 +5662,16 @@ class DGLGraph(object):
...
@@ -5347,12 +5662,16 @@ class DGLGraph(object):
# 2. Copy misc info
# 2. Copy misc info
if
self
.
_batch_num_nodes
is
not
None
:
if
self
.
_batch_num_nodes
is
not
None
:
new_bnn
=
{
k
:
F
.
copy_to
(
num
,
device
,
**
kwargs
)
new_bnn
=
{
for
k
,
num
in
self
.
_batch_num_nodes
.
items
()}
k
:
F
.
copy_to
(
num
,
device
,
**
kwargs
)
for
k
,
num
in
self
.
_batch_num_nodes
.
items
()
}
ret
.
_batch_num_nodes
=
new_bnn
ret
.
_batch_num_nodes
=
new_bnn
if
self
.
_batch_num_edges
is
not
None
:
if
self
.
_batch_num_edges
is
not
None
:
new_bne
=
{
k
:
F
.
copy_to
(
num
,
device
,
**
kwargs
)
new_bne
=
{
for
k
,
num
in
self
.
_batch_num_edges
.
items
()}
k
:
F
.
copy_to
(
num
,
device
,
**
kwargs
)
for
k
,
num
in
self
.
_batch_num_edges
.
items
()
}
ret
.
_batch_num_edges
=
new_bne
ret
.
_batch_num_edges
=
new_bne
return
ret
return
ret
...
@@ -5432,8 +5751,10 @@ class DGLGraph(object):
...
@@ -5432,8 +5751,10 @@ class DGLGraph(object):
tensor([0, 1, 1])
tensor([0, 1, 1])
"""
"""
if
not
self
.
_graph
.
is_pinned
():
if
not
self
.
_graph
.
is_pinned
():
if
F
.
device_type
(
self
.
device
)
!=
'cpu'
:
if
F
.
device_type
(
self
.
device
)
!=
"cpu"
:
raise
DGLError
(
"The graph structure must be on CPU to be pinned."
)
raise
DGLError
(
"The graph structure must be on CPU to be pinned."
)
self
.
_graph
.
pin_memory_
()
self
.
_graph
.
pin_memory_
()
for
frame
in
itertools
.
chain
(
self
.
_node_frames
,
self
.
_edge_frames
):
for
frame
in
itertools
.
chain
(
self
.
_node_frames
,
self
.
_edge_frames
):
for
col
in
frame
.
_columns
.
values
():
for
col
in
frame
.
_columns
.
values
():
...
@@ -5484,9 +5805,9 @@ class DGLGraph(object):
...
@@ -5484,9 +5805,9 @@ class DGLGraph(object):
DGLGraph
DGLGraph
self.
self.
"""
"""
if
F
.
get_preferred_backend
()
!=
'
pytorch
'
:
if
F
.
get_preferred_backend
()
!=
"
pytorch
"
:
raise
DGLError
(
"record_stream only support the PyTorch backend."
)
raise
DGLError
(
"record_stream only support the PyTorch backend."
)
if
F
.
device_type
(
self
.
device
)
!=
'
cuda
'
:
if
F
.
device_type
(
self
.
device
)
!=
"
cuda
"
:
raise
DGLError
(
"The graph must be on GPU to be recorded."
)
raise
DGLError
(
"The graph must be on GPU to be recorded."
)
self
.
_graph
.
record_stream
(
stream
)
self
.
_graph
.
record_stream
(
stream
)
for
frame
in
itertools
.
chain
(
self
.
_node_frames
,
self
.
_edge_frames
):
for
frame
in
itertools
.
chain
(
self
.
_node_frames
,
self
.
_edge_frames
):
...
@@ -5510,15 +5831,24 @@ class DGLGraph(object):
...
@@ -5510,15 +5831,24 @@ class DGLGraph(object):
# Clone the graph structure
# Clone the graph structure
meta_edges
=
[]
meta_edges
=
[]
for
s_ntype
,
_
,
d_ntype
in
self
.
canonical_etypes
:
for
s_ntype
,
_
,
d_ntype
in
self
.
canonical_etypes
:
meta_edges
.
append
((
self
.
get_ntype_id
(
s_ntype
),
self
.
get_ntype_id
(
d_ntype
)))
meta_edges
.
append
(
(
self
.
get_ntype_id
(
s_ntype
),
self
.
get_ntype_id
(
d_ntype
))
)
metagraph
=
graph_index
.
from_edge_list
(
meta_edges
,
True
)
metagraph
=
graph_index
.
from_edge_list
(
meta_edges
,
True
)
# rebuild graph idx
# rebuild graph idx
num_nodes_per_type
=
[
self
.
number_of_nodes
(
c_ntype
)
for
c_ntype
in
self
.
ntypes
]
num_nodes_per_type
=
[
relation_graphs
=
[
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
))
self
.
number_of_nodes
(
c_ntype
)
for
c_ntype
in
self
.
ntypes
for
c_etype
in
self
.
canonical_etypes
]
]
relation_graphs
=
[
self
.
_graph
.
get_relation_graph
(
self
.
get_etype_id
(
c_etype
))
for
c_etype
in
self
.
canonical_etypes
]
ret
.
_graph
=
heterograph_index
.
create_heterograph_from_relations
(
ret
.
_graph
=
heterograph_index
.
create_heterograph_from_relations
(
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
))
metagraph
,
relation_graphs
,
utils
.
toindex
(
num_nodes_per_type
,
"int64"
),
)
# Clone the frames
# Clone the frames
ret
.
_node_frames
=
[
fr
.
clone
()
for
fr
in
self
.
_node_frames
]
ret
.
_node_frames
=
[
fr
.
clone
()
for
fr
in
self
.
_node_frames
]
...
@@ -5819,7 +6149,7 @@ class DGLGraph(object):
...
@@ -5819,7 +6149,7 @@ class DGLGraph(object):
return
ret
return
ret
# TODO: Formats should not be specified, just saving all the materialized formats
# TODO: Formats should not be specified, just saving all the materialized formats
def
shared_memory
(
self
,
name
,
formats
=
(
'
coo
'
,
'
csr
'
,
'
csc
'
)):
def
shared_memory
(
self
,
name
,
formats
=
(
"
coo
"
,
"
csr
"
,
"
csc
"
)):
"""Return a copy of this graph in shared memory, without node data or edge data.
"""Return a copy of this graph in shared memory, without node data or edge data.
It moves the graph index to shared memory and returns a DGLGraph object which
It moves the graph index to shared memory and returns a DGLGraph object which
...
@@ -5843,11 +6173,16 @@ class DGLGraph(object):
...
@@ -5843,11 +6173,16 @@ class DGLGraph(object):
if
isinstance
(
formats
,
str
):
if
isinstance
(
formats
,
str
):
formats
=
[
formats
]
formats
=
[
formats
]
for
fmt
in
formats
:
for
fmt
in
formats
:
assert
fmt
in
(
"coo"
,
"csr"
,
"csc"
),
'{} is not coo, csr or csc'
.
format
(
fmt
)
assert
fmt
in
(
gidx
=
self
.
_graph
.
shared_memory
(
name
,
self
.
ntypes
,
self
.
etypes
,
formats
)
"coo"
,
"csr"
,
"csc"
,
),
"{} is not coo, csr or csc"
.
format
(
fmt
)
gidx
=
self
.
_graph
.
shared_memory
(
name
,
self
.
ntypes
,
self
.
etypes
,
formats
)
return
DGLGraph
(
gidx
,
self
.
ntypes
,
self
.
etypes
)
return
DGLGraph
(
gidx
,
self
.
ntypes
,
self
.
etypes
)
def
long
(
self
):
def
long
(
self
):
"""Cast the graph to one with idtype int64
"""Cast the graph to one with idtype int64
...
@@ -5948,10 +6283,12 @@ class DGLGraph(object):
...
@@ -5948,10 +6283,12 @@ class DGLGraph(object):
"""
"""
return
self
.
astype
(
F
.
int32
)
return
self
.
astype
(
F
.
int32
)
############################################################
############################################################
# Internal APIs
# Internal APIs
############################################################
############################################################
def
make_canonical_etypes
(
etypes
,
ntypes
,
metagraph
):
def
make_canonical_etypes
(
etypes
,
ntypes
,
metagraph
):
"""Internal function to convert etype name to (srctype, etype, dsttype)
"""Internal function to convert etype name to (srctype, etype, dsttype)
...
@@ -5970,19 +6307,29 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
...
@@ -5970,19 +6307,29 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
"""
"""
# sanity check
# sanity check
if
len
(
etypes
)
!=
metagraph
.
number_of_edges
():
if
len
(
etypes
)
!=
metagraph
.
number_of_edges
():
raise
DGLError
(
'Length of edge type list must match the number of '
raise
DGLError
(
'edges in the metagraph. {} vs {}'
.
format
(
"Length of edge type list must match the number of "
len
(
etypes
),
metagraph
.
number_of_edges
()))
"edges in the metagraph. {} vs {}"
.
format
(
len
(
etypes
),
metagraph
.
number_of_edges
()
)
)
if
len
(
ntypes
)
!=
metagraph
.
number_of_nodes
():
if
len
(
ntypes
)
!=
metagraph
.
number_of_nodes
():
raise
DGLError
(
'Length of nodes type list must match the number of '
raise
DGLError
(
'nodes in the metagraph. {} vs {}'
.
format
(
"Length of nodes type list must match the number of "
len
(
ntypes
),
metagraph
.
number_of_nodes
()))
"nodes in the metagraph. {} vs {}"
.
format
(
if
(
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
):
len
(
ntypes
),
metagraph
.
number_of_nodes
()
)
)
if
len
(
etypes
)
==
1
and
len
(
ntypes
)
==
1
:
return
[(
ntypes
[
0
],
etypes
[
0
],
ntypes
[
0
])]
return
[(
ntypes
[
0
],
etypes
[
0
],
ntypes
[
0
])]
src
,
dst
,
eid
=
metagraph
.
edges
(
order
=
"eid"
)
src
,
dst
,
eid
=
metagraph
.
edges
(
order
=
"eid"
)
rst
=
[(
ntypes
[
sid
],
etypes
[
eid
],
ntypes
[
did
])
for
sid
,
did
,
eid
in
zip
(
src
,
dst
,
eid
)]
rst
=
[
(
ntypes
[
sid
],
etypes
[
eid
],
ntypes
[
did
])
for
sid
,
did
,
eid
in
zip
(
src
,
dst
,
eid
)
]
return
rst
return
rst
def
find_src_dst_ntypes
(
ntypes
,
metagraph
):
def
find_src_dst_ntypes
(
ntypes
,
metagraph
):
"""Internal function to split ntypes into SRC and DST categories.
"""Internal function to split ntypes into SRC and DST categories.
...
@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph):
...
@@ -6011,10 +6358,11 @@ def find_src_dst_ntypes(ntypes, metagraph):
return
None
return
None
else
:
else
:
src
,
dst
=
ret
src
,
dst
=
ret
srctypes
=
{
ntypes
[
tid
]
:
tid
for
tid
in
src
}
srctypes
=
{
ntypes
[
tid
]:
tid
for
tid
in
src
}
dsttypes
=
{
ntypes
[
tid
]
:
tid
for
tid
in
dst
}
dsttypes
=
{
ntypes
[
tid
]:
tid
for
tid
in
dst
}
return
srctypes
,
dsttypes
return
srctypes
,
dsttypes
def
pad_tuple
(
tup
,
length
,
pad_val
=
None
):
def
pad_tuple
(
tup
,
length
,
pad_val
=
None
):
"""Pad the given tuple to the given length.
"""Pad the given tuple to the given length.
...
@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None):
...
@@ -6022,7 +6370,7 @@ def pad_tuple(tup, length, pad_val=None):
Return None if pad fails.
Return None if pad fails.
"""
"""
if
not
isinstance
(
tup
,
tuple
):
if
not
isinstance
(
tup
,
tuple
):
tup
=
(
tup
,
)
tup
=
(
tup
,)
if
len
(
tup
)
>
length
:
if
len
(
tup
)
>
length
:
return
None
return
None
elif
len
(
tup
)
==
length
:
elif
len
(
tup
)
==
length
:
...
@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None):
...
@@ -6030,6 +6378,7 @@ def pad_tuple(tup, length, pad_val=None):
else
:
else
:
return
tup
+
(
pad_val
,)
*
(
length
-
len
(
tup
))
return
tup
+
(
pad_val
,)
*
(
length
-
len
(
tup
))
def
reduce_dict_data
(
frames
,
reducer
,
order
=
None
):
def
reduce_dict_data
(
frames
,
reducer
,
order
=
None
):
"""Merge tensor dictionaries into one. Resolve conflict fields using reducer.
"""Merge tensor dictionaries into one. Resolve conflict fields using reducer.
...
@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None):
...
@@ -6054,27 +6403,33 @@ def reduce_dict_data(frames, reducer, order=None):
dict[str, Tensor]
dict[str, Tensor]
Merged frame
Merged frame
"""
"""
if
len
(
frames
)
==
1
and
reducer
!=
'
stack
'
:
if
len
(
frames
)
==
1
and
reducer
!=
"
stack
"
:
# Directly return the only one input. Stack reducer requires
# Directly return the only one input. Stack reducer requires
# modifying tensor shape.
# modifying tensor shape.
return
frames
[
0
]
return
frames
[
0
]
if
callable
(
reducer
):
if
callable
(
reducer
):
merger
=
reducer
merger
=
reducer
elif
reducer
==
'
stack
'
:
elif
reducer
==
"
stack
"
:
# Stack order does not matter. However, it must be consistent!
# Stack order does not matter. However, it must be consistent!
if
order
:
if
order
:
assert
len
(
order
)
==
len
(
frames
)
assert
len
(
order
)
==
len
(
frames
)
sorted_with_key
=
sorted
(
zip
(
frames
,
order
),
key
=
lambda
x
:
x
[
1
])
sorted_with_key
=
sorted
(
zip
(
frames
,
order
),
key
=
lambda
x
:
x
[
1
])
frames
=
list
(
zip
(
*
sorted_with_key
))[
0
]
frames
=
list
(
zip
(
*
sorted_with_key
))[
0
]
def
merger
(
flist
):
def
merger
(
flist
):
return
F
.
stack
(
flist
,
1
)
return
F
.
stack
(
flist
,
1
)
else
:
else
:
redfn
=
getattr
(
F
,
reducer
,
None
)
redfn
=
getattr
(
F
,
reducer
,
None
)
if
redfn
is
None
:
if
redfn
is
None
:
raise
DGLError
(
'Invalid cross type reducer. Must be one of '
raise
DGLError
(
'"sum", "max", "min", "mean" or "stack".'
)
"Invalid cross type reducer. Must be one of "
'"sum", "max", "min", "mean" or "stack".'
)
def
merger
(
flist
):
def
merger
(
flist
):
return
redfn
(
F
.
stack
(
flist
,
0
),
0
)
if
len
(
flist
)
>
1
else
flist
[
0
]
return
redfn
(
F
.
stack
(
flist
,
0
),
0
)
if
len
(
flist
)
>
1
else
flist
[
0
]
keys
=
set
()
keys
=
set
()
for
frm
in
frames
:
for
frm
in
frames
:
keys
.
update
(
frm
.
keys
())
keys
.
update
(
frm
.
keys
())
...
@@ -6087,6 +6442,7 @@ def reduce_dict_data(frames, reducer, order=None):
...
@@ -6087,6 +6442,7 @@ def reduce_dict_data(frames, reducer, order=None):
ret
[
k
]
=
merger
(
flist
)
ret
[
k
]
=
merger
(
flist
)
return
ret
return
ret
def
combine_frames
(
frames
,
ids
,
col_names
=
None
):
def
combine_frames
(
frames
,
ids
,
col_names
=
None
):
"""Merge the frames into one frame, taking the common columns.
"""Merge the frames into one frame, taking the common columns.
...
@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None):
...
@@ -6120,8 +6476,10 @@ def combine_frames(frames, ids, col_names=None):
for
key
,
scheme
in
list
(
schemes
.
items
()):
for
key
,
scheme
in
list
(
schemes
.
items
()):
if
key
in
frame
.
schemes
:
if
key
in
frame
.
schemes
:
if
frame
.
schemes
[
key
]
!=
scheme
:
if
frame
.
schemes
[
key
]
!=
scheme
:
raise
DGLError
(
'Cannot concatenate column %s with shape %s and shape %s'
%
raise
DGLError
(
(
key
,
frame
.
schemes
[
key
],
scheme
))
"Cannot concatenate column %s with shape %s and shape %s"
%
(
key
,
frame
.
schemes
[
key
],
scheme
)
)
else
:
else
:
del
schemes
[
key
]
del
schemes
[
key
]
...
@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None):
...
@@ -6133,6 +6491,7 @@ def combine_frames(frames, ids, col_names=None):
cols
=
{
key
:
F
.
cat
(
to_cat
(
key
),
dim
=
0
)
for
key
in
schemes
}
cols
=
{
key
:
F
.
cat
(
to_cat
(
key
),
dim
=
0
)
for
key
in
schemes
}
return
Frame
(
cols
)
return
Frame
(
cols
)
def
combine_names
(
names
,
ids
=
None
):
def
combine_names
(
names
,
ids
=
None
):
"""Combine the selected names into one new name.
"""Combine the selected names into one new name.
...
@@ -6148,40 +6507,59 @@ def combine_names(names, ids=None):
...
@@ -6148,40 +6507,59 @@ def combine_names(names, ids=None):
str
str
"""
"""
if
ids
is
None
:
if
ids
is
None
:
return
'+'
.
join
(
sorted
(
names
))
return
"+"
.
join
(
sorted
(
names
))
else
:
else
:
selected
=
sorted
([
names
[
i
]
for
i
in
ids
])
selected
=
sorted
([
names
[
i
]
for
i
in
ids
])
return
'+'
.
join
(
selected
)
return
"+"
.
join
(
selected
)
class
DGLBlock
(
DGLGraph
):
class
DGLBlock
(
DGLGraph
):
"""Subclass that signifies the graph is a block created from
"""Subclass that signifies the graph is a block created from
:func:`dgl.to_block`.
:func:`dgl.to_block`.
"""
"""
# (BarclayII) I'm making a subclass because I don't want to make another version of
# (BarclayII) I'm making a subclass because I don't want to make another version of
# serialization that contains the is_block flag.
# serialization that contains the is_block flag.
is_block
=
True
is_block
=
True
def
__repr__
(
self
):
def
__repr__
(
self
):
if
len
(
self
.
srctypes
)
==
1
and
len
(
self
.
dsttypes
)
==
1
and
len
(
self
.
etypes
)
==
1
:
if
(
ret
=
'Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})'
len
(
self
.
srctypes
)
==
1
and
len
(
self
.
dsttypes
)
==
1
and
len
(
self
.
etypes
)
==
1
):
ret
=
"Block(num_src_nodes={srcnode}, num_dst_nodes={dstnode}, num_edges={edge})"
return
ret
.
format
(
return
ret
.
format
(
srcnode
=
self
.
number_of_src_nodes
(),
srcnode
=
self
.
number_of_src_nodes
(),
dstnode
=
self
.
number_of_dst_nodes
(),
dstnode
=
self
.
number_of_dst_nodes
(),
edge
=
self
.
number_of_edges
())
edge
=
self
.
number_of_edges
(),
)
else
:
else
:
ret
=
(
'Block(num_src_nodes={srcnode},
\n
'
ret
=
(
' num_dst_nodes={dstnode},
\n
'
"Block(num_src_nodes={srcnode},
\n
"
' num_edges={edge},
\n
'
" num_dst_nodes={dstnode},
\n
"
' metagraph={meta})'
)
" num_edges={edge},
\n
"
nsrcnode_dict
=
{
ntype
:
self
.
number_of_src_nodes
(
ntype
)
" metagraph={meta})"
for
ntype
in
self
.
srctypes
}
)
ndstnode_dict
=
{
ntype
:
self
.
number_of_dst_nodes
(
ntype
)
nsrcnode_dict
=
{
for
ntype
in
self
.
dsttypes
}
ntype
:
self
.
number_of_src_nodes
(
ntype
)
nedge_dict
=
{
etype
:
self
.
number_of_edges
(
etype
)
for
ntype
in
self
.
srctypes
for
etype
in
self
.
canonical_etypes
}
}
ndstnode_dict
=
{
ntype
:
self
.
number_of_dst_nodes
(
ntype
)
for
ntype
in
self
.
dsttypes
}
nedge_dict
=
{
etype
:
self
.
number_of_edges
(
etype
)
for
etype
in
self
.
canonical_etypes
}
meta
=
str
(
self
.
metagraph
().
edges
(
keys
=
True
))
meta
=
str
(
self
.
metagraph
().
edges
(
keys
=
True
))
return
ret
.
format
(
return
ret
.
format
(
srcnode
=
nsrcnode_dict
,
dstnode
=
ndstnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
)
srcnode
=
nsrcnode_dict
,
dstnode
=
ndstnode_dict
,
edge
=
nedge_dict
,
meta
=
meta
,
)
def
_create_compute_graph
(
graph
,
u
,
v
,
eid
,
recv_nodes
=
None
):
def
_create_compute_graph
(
graph
,
u
,
v
,
eid
,
recv_nodes
=
None
):
...
@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
...
@@ -6235,17 +6613,32 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
srctype
,
etype
,
dsttype
=
graph
.
canonical_etypes
[
0
]
srctype
,
etype
,
dsttype
=
graph
.
canonical_etypes
[
0
]
# create graph
# create graph
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
hgidx
=
heterograph_index
.
create_unitgraph_from_coo
(
2
,
len
(
unique_src
),
len
(
unique_dst
),
new_u
,
new_v
,
[
'coo'
,
'csr'
,
'csc'
])
2
,
len
(
unique_src
),
len
(
unique_dst
),
new_u
,
new_v
,
[
"coo"
,
"csr"
,
"csc"
]
)
# create frame
# create frame
srcframe
=
graph
.
_node_frames
[
graph
.
get_ntype_id
(
srctype
)].
subframe
(
unique_src
)
srcframe
=
graph
.
_node_frames
[
graph
.
get_ntype_id
(
srctype
)].
subframe
(
unique_src
)
srcframe
[
NID
]
=
unique_src
srcframe
[
NID
]
=
unique_src
dstframe
=
graph
.
_node_frames
[
graph
.
get_ntype_id
(
dsttype
)].
subframe
(
unique_dst
)
dstframe
=
graph
.
_node_frames
[
graph
.
get_ntype_id
(
dsttype
)].
subframe
(
unique_dst
)
dstframe
[
NID
]
=
unique_dst
dstframe
[
NID
]
=
unique_dst
eframe
=
graph
.
_edge_frames
[
0
].
subframe
(
eid
)
eframe
=
graph
.
_edge_frames
[
0
].
subframe
(
eid
)
eframe
[
EID
]
=
eid
eframe
[
EID
]
=
eid
return
DGLGraph
(
hgidx
,
([
srctype
],
[
dsttype
]),
[
etype
],
return
(
DGLGraph
(
hgidx
,
([
srctype
],
[
dsttype
]),
[
etype
],
node_frames
=
[
srcframe
,
dstframe
],
node_frames
=
[
srcframe
,
dstframe
],
edge_frames
=
[
eframe
]),
unique_src
,
unique_dst
,
eid
edge_frames
=
[
eframe
],
),
unique_src
,
unique_dst
,
eid
,
)
_init_api
(
"dgl.heterograph"
)
_init_api
(
"dgl.heterograph"
)
python/dgl/heterograph_index.py
View file @
d1827488
...
@@ -7,12 +7,11 @@ import sys
...
@@ -7,12 +7,11 @@ import sys
import
numpy
as
np
import
numpy
as
np
import
scipy
import
scipy
from
.
import
backend
as
F
from
.
import
backend
as
F
,
utils
from
.
import
utils
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
from
._ffi.object
import
ObjectBase
,
register_object
from
._ffi.object
import
ObjectBase
,
register_object
from
._ffi.streams
import
to_dgl_stream_handle
from
._ffi.streams
import
to_dgl_stream_handle
from
.base
import
DGLError
,
dgl_warning
from
.base
import
dgl_warning
,
DGLError
from
.graph_index
import
from_coo
from
.graph_index
import
from_coo
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment